{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SAE Dashboard Demo Notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps:\n",
    "1. Download SAE with SAE Lens.\n",
    "2. Create a token dataset consistent with that SAE.\n",
    "3. Fold the SAE decoder (`W_dec`) norms into the SAE weights so that feature activations are correctly scaled.\n",
    "4. If needed, estimate the activation normalization constant and fold it into the SAE weights (this demo does not perform this step).\n",
    "5. Run the SAE dashboard generator (`SaeVisRunner`) for the features you want."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformer_lens import HookedTransformer\n",
    "from sae_lens import ActivationsStore, SAE\n",
    "import sae_dashboard\n",
    "\n",
    "# Pick a backend: Apple MPS if available, otherwise CUDA, otherwise CPU.\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "print(f\"Device: {device}\")\n",
    "\n",
    "# The reference setup shards the model over 4 GPUs and puts the SAE on the last one (cuda:3).\n",
    "# Derive both from the hardware actually present so the notebook also runs with fewer GPUs,\n",
    "# or on MPS/CPU (single device). With >= 4 GPUs this reproduces n_devices=4 / cuda:3.\n",
    "N_MODEL_DEVICES = min(4, torch.cuda.device_count()) if device == \"cuda\" else 1\n",
    "SAE_DEVICE = f\"cuda:{N_MODEL_DEVICES - 1}\" if device == \"cuda\" else device\n",
    "\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    \"mistral-7b\", device=device, n_devices=N_MODEL_DEVICES, dtype=\"bfloat16\"\n",
    ")\n",
    "\n",
    "# The cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store).\n",
    "# Note that this is not the same as the SAE's config dict; rather, it is whatever was in the HF repo, from which we can extract the SAE config dict.\n",
    "# We also return the feature sparsities, which are stored in HF for convenience.\n",
    "sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "    release=\"mistral-7b-res-wg\",  # see other options in sae_lens/pretrained_saes.yaml\n",
    "    sae_id=\"blocks.8.hook_resid_pre\",  # won't always be a hook point\n",
    "    device=SAE_DEVICE,\n",
    ")\n",
    "# Fold the W_dec norm into the SAE weights so feature activations are accurate.\n",
    "sae.fold_W_dec_norm()\n",
    "\n",
    "activations_store = ActivationsStore.from_sae(\n",
    "    model=model,\n",
    "    sae=sae,\n",
    "    streaming=True,\n",
    "    store_batch_size_prompts=8,\n",
    "    n_batches_in_buffer=8,\n",
    "    device=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def get_tokens(\n",
    "    activations_store: ActivationsStore,\n",
    "    n_batches_to_sample_from: int = 4096 * 6,\n",
    "    n_prompts_to_select: int = 4096 * 6,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Pull token batches from `activations_store` and return a random subset of prompts.\n",
    "\n",
    "    Args:\n",
    "        activations_store: store whose `get_batch_tokens()` returns one batch of prompts per call.\n",
    "        n_batches_to_sample_from: number of batches to pull from the store.\n",
    "        n_prompts_to_select: maximum number of prompts (rows) to return.\n",
    "\n",
    "    Returns:\n",
    "        All pulled batches concatenated along dim 0, shuffled, and truncated to\n",
    "        `n_prompts_to_select` rows.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `n_batches_to_sample_from` is not positive (nothing to concatenate).\n",
    "    \"\"\"\n",
    "    if n_batches_to_sample_from <= 0:\n",
    "        raise ValueError(\"n_batches_to_sample_from must be positive\")\n",
    "\n",
    "    all_tokens_list = [\n",
    "        activations_store.get_batch_tokens()\n",
    "        for _ in tqdm(range(n_batches_to_sample_from), desc=\"sampling batches\")\n",
    "    ]\n",
    "    all_tokens = torch.cat(all_tokens_list, dim=0)\n",
    "    # One global shuffle is enough; shuffling each batch first would be redundant.\n",
    "    all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]\n",
    "    return all_tokens[:n_prompts_to_select]\n",
    "\n",
    "\n",
    "# The defaults pull 4096 * 6 batches, which is slow; a few thousand prompts is plenty for a demo, e.g.\n",
    "# token_dataset = get_tokens(activations_store, n_batches_to_sample_from=512, n_prompts_to_select=4096)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Cache the sampled tokens on disk: reuse them if present, otherwise sample once and save,\n",
    "# so this cell also works on a fresh checkout (Restart & Run All).\n",
    "TOKEN_DATASET_PATH = Path(\"token_dataset.pt\")\n",
    "if TOKEN_DATASET_PATH.exists():\n",
    "    token_dataset = torch.load(TOKEN_DATASET_PATH)\n",
    "else:\n",
    "    token_dataset = get_tokens(activations_store)\n",
    "    torch.save(token_dataset, TOKEN_DATASET_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "# Start from an empty activations cache. `os.rmdir` only removes *empty* directories and\n",
    "# raises if the directory is missing, so remove the whole tree, but only if it exists.\n",
    "activations_cache_dir = Path(\"demo_activations_cache\")\n",
    "if activations_cache_dir.exists():\n",
    "    shutil.rmtree(activations_cache_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "N_FEATURES_TO_VISUALIZE = 256  # visualize SAE features 0..255\n",
    "N_PROMPTS_FOR_DASHBOARD = 4096  # prompts from token_dataset fed to the runner\n",
    "\n",
    "feature_idx = list(range(N_FEATURES_TO_VISUALIZE))\n",
    "\n",
    "feature_vis_config = sae_dashboard.SaeVisConfig(\n",
    "    hook_point=sae.cfg.hook_name,\n",
    "    features=feature_idx,\n",
    "    minibatch_size_features=16,\n",
    "    # Despite the name, this counts prompts; tokens per prompt are set by the sequence length.\n",
    "    minibatch_size_tokens=32,\n",
    "    verbose=True,\n",
    "    device=\"cuda\",\n",
    "    # Caching activations here lets subsequent runs (e.g. other features) skip the model.\n",
    "    cache_dir=Path(\"demo_activations_cache\"),\n",
    "    dtype=\"bfloat16\",\n",
    ")\n",
    "\n",
    "runner = sae_dashboard.SaeVisRunner(feature_vis_config)\n",
    "\n",
    "data = runner.run(\n",
    "    encoder=sae,  # type: ignore\n",
    "    model=model,\n",
    "    tokens=token_dataset[:N_PROMPTS_FOR_DASHBOARD],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_dashboard.data_writing_fns import save_feature_centric_vis\n",
    "\n",
    "# Write the feature-centric dashboards computed above (`data`) to an HTML file.\n",
    "filename = \"demo_feature_dashboards.html\"\n",
    "save_feature_centric_vis(sae_vis_data=data, filename=filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quick Profiling Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the store for profiling: 32 prompts per batch and a single batch in the buffer.\n",
    "activations_store = ActivationsStore.from_sae(\n",
    "    model=model,\n",
    "    sae=sae,\n",
    "    device=\"cpu\",\n",
    "    streaming=True,\n",
    "    n_batches_in_buffer=1,\n",
    "    store_batch_size_prompts=32,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input dimension of the SAE (`d_in`), shown as this cell's output.\n",
    "sae.cfg.d_in"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "import torch\n",
    "from torch.profiler import ProfilerActivity, profile, record_function\n",
    "\n",
    "# Release memory left over from earlier cells before profiling.\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def my_function(n_iters: int = 5, n_prompts: int = 32) -> None:\n",
    "    \"\"\"Run the model `n_iters` times on the first `n_prompts` prompts, caching only the SAE hook.\n",
    "\n",
    "    Uses `stop_at_layer=sae.cfg.hook_layer + 1` and `names_filter=sae.cfg.hook_name`, so only the\n",
    "    work needed to produce the SAE's input activations is profiled.\n",
    "    \"\"\"\n",
    "    tokens = token_dataset[:n_prompts]  # loop-invariant, so slice once\n",
    "    for _ in range(n_iters):\n",
    "        _, cache = model.run_with_cache(\n",
    "            tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=sae.cfg.hook_name\n",
    "        )\n",
    "        sae_in = cache[sae.cfg.hook_name]\n",
    "        # Drop references so this iteration's activations can be freed before the next run.\n",
    "        del cache, sae_in\n",
    "\n",
    "\n",
    "# Only request CUDA tracing when CUDA is actually available.\n",
    "activities = [ProfilerActivity.CPU]\n",
    "if torch.cuda.is_available():\n",
    "    activities.append(ProfilerActivity.CUDA)\n",
    "\n",
    "with profile(activities=activities, record_shapes=True) as prof:\n",
    "    with record_function(\"my_function\"):\n",
    "        my_function()\n",
    "\n",
    "print(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
