{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo Notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps:\n",
    "1. Download SAE with SAE Lens.\n",
    "2. Create a dataset consistent with that SAE.\n",
    "3. Fold the SAE decoder weight norms into the SAE so that feature activations are correctly scaled.\n",
    "4. Estimate the activation normalization constant if needed, and fold it into the SAE weights.\n",
    "5. Run the SAE generator for the features you want."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the Gemma-2-9b SAE artifact (type \"model\") from Weights & Biases.\n",
    "# `artifact.download()` returns the local directory the files were written to.\n",
    "\n",
    "import wandb\n",
    "\n",
    "run = wandb.init()\n",
    "artifact = run.use_artifact(\n",
    "    \"jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7\",\n",
    "    type=\"model\",\n",
    ")\n",
    "artifact_dir = artifact.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the log-feature-sparsity artifact that belongs to the same SAE.\n",
    "# NOTE: this rebinds `run`, `artifact` and `artifact_dir` from the previous cell.\n",
    "import wandb\n",
    "\n",
    "run = wandb.init()\n",
    "artifact = run.use_artifact(\n",
    "    \"jbloom/gemma-2-9b_test/sae_gemma-2-9b_blocks.24.hook_resid_post_114688_log_feature_sparsity:v7\",\n",
    "    type=\"log_feature_sparsity\",\n",
    ")\n",
    "artifact_dir = artifact.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from safetensors.torch import load_file\n",
    "\n",
    "# Per-feature sparsity values stored next to the SAE weights.\n",
    "feature_sparsity = load_file(\n",
    "    \"artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7/sparsity.safetensors\"\n",
    ")[\"sparsity\"]\n",
    "\n",
    "# Use a dedicated name for the NumPy view: `data` is rebound further down to the\n",
    "# dashboard runner output, so sharing the name would hide which object is meant.\n",
    "feature_sparsity_np = feature_sparsity.numpy()\n",
    "\n",
    "# Explicit figure/axes so the plot carries its own labels and can be reused.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(feature_sparsity_np, bins=30, edgecolor=\"black\")\n",
    "ax.set_xlabel(\"Feature sparsity (as stored in sparsity.safetensors)\")\n",
    "ax.set_ylabel(\"Number of features\")\n",
    "ax.set_title(\"Distribution of SAE feature sparsity\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformer_lens import HookedTransformer\n",
    "from sae_lens import ActivationsStore, SAE\n",
    "from importlib import reload\n",
    "import sae_dashboard\n",
    "\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "reload(sae_dashboard)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = \"gemma-2-9b\"\n",
    "\n",
    "# Device preference order: Apple MPS, then CUDA, then CPU.\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "print(f\"Device: {device}\")\n",
    "\n",
    "# Load the language model in bfloat16 on the selected device.\n",
    "model = HookedTransformer.from_pretrained(MODEL, device=device, dtype=\"bfloat16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SAE from the locally downloaded artifact directory.\n",
    "sae = SAE.load_from_pretrained(\n",
    "    \"artifacts/sae_gemma-2-9b_blocks.24.hook_resid_post_114688:v7\"\n",
    ")\n",
    "# Step 3 of the outline at the top of the notebook: fold the decoder norms into the SAE.\n",
    "sae.fold_W_dec_norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: run the SAE on a random tensor instead of real model activations.\n",
    "# The commented-out lines show how real activations could be taken from the model cache.\n",
    "# _, cache = model.run_with_cache(\"Wasssssup\", names_filter = sae.cfg.hook_name)\n",
    "# sae_in = cache[sae.cfg.hook_name]\n",
    "# print(sae_in.shape)\n",
    "# NOTE(review): 3584 is presumably the SAE input width — confirm against sae.cfg.\n",
    "sae_in = torch.rand((1, 4, 3584)).to(sae.device)\n",
    "sae_out = sae(sae_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the commented-out block below loads a different SAE (\"mistral-7b-res-wg\")\n",
    "# and is not used by this notebook; it looks like a leftover from an earlier demo.\n",
    "# # the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n",
    "# # Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n",
    "# # We also return the feature sparsities which are stored in HF for convenience.\n",
    "# sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "#     release = \"mistral-7b-res-wg\", # see other options in sae_lens/pretrained_saes.yaml\n",
    "#     sae_id = \"blocks.8.hook_resid_pre\", # won't always be a hook point\n",
    "#     device = \"cuda:3\",\n",
    "# )\n",
    "# # fold w_dec norm so feature activations are accurate\n",
    "#\n",
    "# Build the activation store that the eval and token-sampling cells below read from.\n",
    "activations_store = ActivationsStore.from_sae(\n",
    "    model=model,\n",
    "    sae=sae,\n",
    "    streaming=True,\n",
    "    store_batch_size_prompts=8,\n",
    "    n_batches_in_buffer=8,\n",
    "    device=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the encode function currently bound on the SAE.\n",
    "sae.encode_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import run_evals\n",
    "\n",
    "eval_metrics = run_evals(\n",
    "    sae=sae,\n",
    "    activation_store=activations_store,\n",
    "    model=model,\n",
    "    n_eval_batches=3,\n",
    "    eval_batch_size_prompts=8,\n",
    ")\n",
    "\n",
    "# Sanity checks, printed in this order:\n",
    "#   CE_loss_score       - should be high for residual stream SAEs\n",
    "#   ce_loss_without_sae - should be fairly low (< 3.5), suggesting the model is being run correctly\n",
    "#   ce_loss_with_sae    - shouldn't be massively higher than without the SAE\n",
    "for metric_key in (\n",
    "    \"metrics/CE_loss_score\",\n",
    "    \"metrics/ce_loss_without_sae\",\n",
    "    \"metrics/ce_loss_with_sae\",\n",
    "):\n",
    "    print(eval_metrics[metric_key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def get_tokens(\n",
    "    activations_store: ActivationsStore,\n",
    "    n_batches_to_sample_from: int = 4096,\n",
    "):\n",
    "    \"\"\"Draw token batches from the store and return them as one shuffled tensor.\n",
    "\n",
    "    Args:\n",
    "        activations_store: Store whose `get_batch_tokens()` is called once per batch.\n",
    "        n_batches_to_sample_from: Number of batches to draw.\n",
    "\n",
    "    Returns:\n",
    "        All drawn rows concatenated along dim 0, in random order.\n",
    "    \"\"\"\n",
    "    shuffled_batches = []\n",
    "    for _ in tqdm(range(n_batches_to_sample_from)):\n",
    "        batch = activations_store.get_batch_tokens()\n",
    "        n_rows = batch.shape[0]\n",
    "        # Shuffle the rows of this batch; the trailing slice keeps every row.\n",
    "        shuffled_batches.append(batch[torch.randperm(n_rows)][:n_rows])\n",
    "\n",
    "    merged = torch.cat(shuffled_batches, dim=0)\n",
    "    # Second shuffle mixes rows across batches.\n",
    "    return merged[torch.randperm(merged.shape[0])]\n",
    "\n",
    "\n",
    "# The original note says ~1000 prompts is plenty for a demo.\n",
    "# NOTE(review): this call uses the default of 4096 batches, while the dashboard cell\n",
    "# below only consumes token_dataset[:1024] — consider passing a smaller batch count.\n",
    "token_dataset = get_tokens(activations_store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(token_dataset, \"to\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "TOKEN_DATASET_PATH = \"token_dataset.pt\"\n",
    "\n",
    "# Reuse the cached tokens when the file exists; otherwise persist the tokens sampled\n",
    "# above. An unconditional torch.load fails on a fresh run because nothing ever\n",
    "# wrote the file.\n",
    "if os.path.exists(TOKEN_DATASET_PATH):\n",
    "    token_dataset = torch.load(TOKEN_DATASET_PATH)\n",
    "else:\n",
    "    torch.save(token_dataset, TOKEN_DATASET_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "# os.rmdir only removes an existing *empty* directory, so it raises both on a fresh\n",
    "# checkout (no directory yet) and after a run that populated the cache.\n",
    "# rmtree with ignore_errors handles both cases.\n",
    "shutil.rmtree(\"demo_activations_cache\", ignore_errors=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def select_indices_in_range(tensor, min_val, max_val, num_samples=None):\n",
    "    \"\"\"\n",
    "    Select indices of a tensor where values fall within a specified range.\n",
    "\n",
    "    Args:\n",
    "    tensor (torch.Tensor): Input tensor with values between -10 and 0.\n",
    "    min_val (float): Minimum value of the range (inclusive).\n",
    "    max_val (float): Maximum value of the range (inclusive).\n",
    "    num_samples (int, optional): Number of indices to randomly select. If None, return all indices.\n",
    "\n",
    "    Returns:\n",
    "    torch.Tensor: Tensor of selected indices. For a 1-D input this is always 1-D,\n",
    "    including when zero or exactly one value matches.\n",
    "\n",
    "    Raises:\n",
    "    ValueError: If the range is not within [-10, 0] or min_val > max_val.\n",
    "    \"\"\"\n",
    "    # Ensure the input range is valid\n",
    "    if not (-10 <= min_val <= max_val <= 0):\n",
    "        raise ValueError(\n",
    "            \"Range must be within -10 to 0, and min_val must be <= max_val\"\n",
    "        )\n",
    "\n",
    "    # Find indices where values are within the specified range\n",
    "    mask = (tensor >= min_val) & (tensor <= max_val)\n",
    "    # squeeze(-1), not squeeze(): a bare squeeze() collapses a single match into a\n",
    "    # 0-dim tensor, which makes `.tolist()` return an int instead of a list.\n",
    "    indices = mask.nonzero().squeeze(-1)\n",
    "\n",
    "    # Count matches along dim 0 so the permutation length equals the number of rows.\n",
    "    n_matches = indices.shape[0]\n",
    "\n",
    "    # If num_samples is specified and less than the total number of valid indices,\n",
    "    # randomly select that many indices\n",
    "    if num_samples is not None and num_samples < n_matches:\n",
    "        perm = torch.randperm(n_matches)\n",
    "        indices = indices[perm[:num_samples]]\n",
    "\n",
    "    return indices\n",
    "\n",
    "\n",
    "# Maximum number of features to build dashboards for.\n",
    "n_features = 4096\n",
    "# Pass the constant instead of repeating the literal so the two cannot drift apart.\n",
    "feature_idxs = select_indices_in_range(feature_sparsity, -4, -2, n_features)\n",
    "feature_sparsity[feature_idxs.tolist()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from importlib import reload\n",
    "import sys\n",
    "\n",
    "\n",
    "def reload_user_modules(module_names):\n",
    "    \"\"\"Reload each named module that is already imported; names not yet imported are skipped.\"\"\"\n",
    "    for module_name in module_names:\n",
    "        if module_name not in sys.modules:\n",
    "            continue\n",
    "        reload(sys.modules[module_name])\n",
    "\n",
    "\n",
    "# List of your module names\n",
    "user_modules = [\n",
    "    \"sae_dashboard\",\n",
    "    \"sae_dashboard.sae_vis_runner\",\n",
    "    \"sae_dashboard.data_parsing_fns\",\n",
    "    \"sae_dashboard.feature_data_generator\",\n",
    "]\n",
    "\n",
    "# Reload modules\n",
    "reload_user_modules(user_modules)\n",
    "\n",
    "# Re-import after reload\n",
    "from sae_dashboard.feature_data_generator import FeatureDataGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# `sae_vis_runner` is referenced below but was never bound as a name in this notebook\n",
    "# (only the top-level `sae_dashboard` package is imported), so import it explicitly.\n",
    "from sae_dashboard import sae_vis_runner\n",
    "\n",
    "test_feature_idx_gpt = feature_idxs.tolist()\n",
    "\n",
    "feature_vis_config_gpt = sae_vis_runner.SaeVisConfig(\n",
    "    hook_point=sae.cfg.hook_name,\n",
    "    features=test_feature_idx_gpt,\n",
    "    minibatch_size_features=16,\n",
    "    minibatch_size_tokens=4096,  # this is really prompt with the number of tokens determined by the sequence length\n",
    "    verbose=True,\n",
    "    # Use the device chosen in the set-up cell instead of hard-coding \"cuda\",\n",
    "    # so the dashboard run matches where the model was loaded.\n",
    "    device=device,\n",
    "    cache_dir=Path(\n",
    "        \"demo_activations_cache\"\n",
    "    ),  # this will enable us to skip running the model for subsequent features.\n",
    "    dtype=\"bfloat16\",\n",
    ")\n",
    "\n",
    "runner = sae_vis_runner.SaeVisRunner(feature_vis_config_gpt)\n",
    "\n",
    "# Only the first 1024 prompts of the token dataset are used for the dashboards.\n",
    "data = runner.run(\n",
    "    encoder=sae,  # type: ignore\n",
    "    model=model,\n",
    "    tokens=token_dataset[:1024],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_dashboard.data_writing_fns import save_feature_centric_vis\n",
    "\n",
    "# Plain string: the previous f-string had no placeholders.\n",
    "filename = \"demo_feature_dashboards.html\"\n",
    "# Write the feature-centric dashboard built from the runner output.\n",
    "save_feature_centric_vis(sae_vis_data=data, filename=filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quick Profiling experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mock_feature_acts_subset_for_now(sae: SAE):\n",
    "    \"\"\"Attach a `get_feature_acts_subset` function to `sae` and return the same object.\n",
    "\n",
    "    The attached function encodes its input with `sae.encode_fn` and keeps only the\n",
    "    requested entries along the last dimension.\n",
    "    \"\"\"\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def sae_lens_get_feature_acts_subset(x: torch.Tensor, feature_idx):  # type: ignore\n",
    "        \"\"\"\n",
    "        Get a subset of the feature activations for a dataset.\n",
    "        \"\"\"\n",
    "        # Encode on the SAE's device/dtype, then return the slice on the caller's device.\n",
    "        source_device = x.device\n",
    "        encoder_input = x.to(device=sae.device, dtype=sae.dtype)\n",
    "        all_acts = sae.encode_fn(encoder_input)\n",
    "        selected_acts = all_acts[..., feature_idx]\n",
    "        return selected_acts.to(source_device)\n",
    "\n",
    "    sae.get_feature_acts_subset = sae_lens_get_feature_acts_subset  # type: ignore\n",
    "\n",
    "    return sae\n",
    "\n",
    "\n",
    "# Patch the SAE, then request the first 128 features for a random input.\n",
    "sae = mock_feature_acts_subset_for_now(sae)\n",
    "# NOTE: rebinds `feature_idxs` (a tensor in the dashboard section above) to a plain list.\n",
    "feature_idxs = list(range(128))\n",
    "sae_in = torch.rand((1, 4, 3584)).to(sae.device)\n",
    "sae.get_feature_acts_subset(sae_in, feature_idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the name and shape of every parameter registered on the SAE.\n",
    "for k, v in sae.named_parameters():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from typing import List\n",
    "\n",
    "\n",
    "class FeatureMaskingContext:\n",
    "    \"\"\"Context manager that temporarily restricts an SAE to a subset of its features.\n",
    "\n",
    "    On entry, each parameter that has a feature dimension is replaced by a new\n",
    "    parameter holding only the rows/columns selected by `feature_idxs`. On exit the\n",
    "    original parameter objects are put back.\n",
    "\n",
    "    Args:\n",
    "        sae: The SAE whose parameters are swapped. `sae.cfg.architecture` must be\n",
    "            \"standard\" or \"gated\".\n",
    "        feature_idxs: Indices of the features to keep while the context is active.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: On entry, if the architecture is not supported. The check runs\n",
    "            before any parameter is replaced, so the SAE is left untouched.\n",
    "    \"\"\"\n",
    "\n",
    "    # Per architecture: parameter name -> axis that indexes features.\n",
    "    _FEATURE_AXES = {\n",
    "        \"standard\": {\"W_dec\": 0, \"W_enc\": 1, \"b_enc\": 0},\n",
    "        \"gated\": {\"W_dec\": 0, \"W_enc\": 1, \"b_gate\": 0, \"r_mag\": 0, \"b_mag\": 0},\n",
    "    }\n",
    "\n",
    "    def __init__(self, sae: SAE, feature_idxs: List):\n",
    "        self.sae = sae\n",
    "        self.feature_idxs = feature_idxs\n",
    "        # Parameter name -> original parameter object, filled while the context is active.\n",
    "        self.original_weight = {}\n",
    "\n",
    "    def __enter__(self):\n",
    "        architecture = self.sae.cfg.architecture\n",
    "        # Validate first: raising after some weights were already swapped would leave\n",
    "        # the SAE half-masked, because __exit__ is not called when __enter__ raises.\n",
    "        if architecture not in self._FEATURE_AXES:\n",
    "            raise ValueError(\"Invalid architecture\")\n",
    "\n",
    "        try:\n",
    "            for name, feature_axis in self._FEATURE_AXES[architecture].items():\n",
    "                # Always read from self.sae, never from a notebook-level `sae` variable.\n",
    "                original = getattr(self.sae, name)\n",
    "                # Keep a reference rather than a clone: the indexing below builds a new\n",
    "                # tensor and never mutates the original, so no copy of the full weight is needed.\n",
    "                self.original_weight[name] = original\n",
    "\n",
    "                weight = original.detach()\n",
    "                if feature_axis == 0:\n",
    "                    masked_weight = weight[self.feature_idxs]\n",
    "                else:\n",
    "                    masked_weight = weight[:, self.feature_idxs]\n",
    "                setattr(self.sae, name, nn.Parameter(masked_weight))\n",
    "        except Exception:\n",
    "            # Undo any swaps already made before propagating the error.\n",
    "            self.__exit__(None, None, None)\n",
    "            raise\n",
    "\n",
    "        return self\n",
    "\n",
    "    def __exit__(self, exc_type, exc_value, traceback):\n",
    "        # Restore the original parameter objects and drop our references to them.\n",
    "        for name, original in self.original_weight.items():\n",
    "            setattr(self.sae, name, original)\n",
    "        self.original_weight = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "\n",
    "def my_function(sae_in, n_features=2048):\n",
    "    \"\"\"Run the SAE on `sae_in` with only the leading features kept and print the output mean.\n",
    "\n",
    "    Args:\n",
    "        sae_in: Input passed to `sae(...)`.\n",
    "        n_features: Number of leading feature indices (0..n_features-1) kept while\n",
    "            the forward pass runs. Defaults to 2048.\n",
    "    \"\"\"\n",
    "    feature_idxs = list(range(n_features))\n",
    "    with FeatureMaskingContext(sae, feature_idxs):\n",
    "        # `sae(...)` is the SAE's forward call, so the result is named as its output.\n",
    "        sae_out = sae(sae_in)\n",
    "        print(sae_out.mean())\n",
    "\n",
    "\n",
    "# Only the first 64 prompts are used for the profiling run.\n",
    "tokens = token_dataset[:64]\n",
    "# Stop the forward pass right after the SAE's layer and cache only its hook point.\n",
    "_, cache = model.run_with_cache(\n",
    "    tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=sae.cfg.hook_name\n",
    ")\n",
    "# Activations at the SAE hook point, used as the SAE input below.\n",
    "sae_in = cache[sae.cfg.hook_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the token batch used for profiling.\n",
    "tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the SAE decoder weight.\n",
    "sae.W_dec.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext memray"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%memray_flamegraph --trace-python-allocators --leaks\n",
    "my_function(sae_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
