{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/flexible_demo.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flexible Training Demo\n",
    "\n",
    "This demo shows you how to train a sparse autoencoder (SAE) on a\n",
    "[TransformerLens](https://github.com/neelnanda-io/TransformerLens) model. It replicates Neel Nanda's\n",
    "[comment on the Anthropic dictionary learning\n",
    "paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda).\n",
    "\n",
    "## Introduction\n",
    "\n",
    "The way this library works is that we provide all the components necessary to train a sparse\n",
    "autoencoder. For the most part, these are just standard PyTorch modules. For example `AdamWithReset` is\n",
    "just an extension of `torch.optim.Adam`, with a few extra bells and whistles that are needed for training a SAE\n",
    "(e.g. a method to reset the optimizer state when you are also resampling dead neurons).\n",
    "\n",
    "This is very flexible - it's easy for you to extend and change just one component if you want, just\n",
    "like you'd do with a standard PyTorch mode. It also means it's very easy to see what is going on\n",
    "under the hood. However to get you started, the following demo sets up a\n",
    "default SAE that uses the implementation that Neel Nanda used in his comment above.\n",
    "\n",
    "### Approach\n",
    "\n",
    "The approach is pretty simple - we run a training pipeline that alternates between generating\n",
    "activations from a *source model*, and training the *sparse autoencoder* model on these generated\n",
    "activations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if we're in Colab\n",
    "try:\n",
    "    import google.colab  # noqa: F401 # type: ignore\n",
    "\n",
    "    in_colab = True\n",
    "except ImportError:\n",
    "    in_colab = False\n",
    "\n",
    "#  Install if in Colab\n",
    "if in_colab:\n",
    "    %pip install sparse_autoencoder transformer_lens transformers wandb\n",
    "\n",
    "# Otherwise enable hot reloading in dev mode\n",
    "if not in_colab:\n",
    "    from IPython import get_ipython  # type: ignore\n",
    "\n",
    "    ip = get_ipython()\n",
    "    if ip is not None and ip.extension_manager is not None and not ip.extension_manager.loaded:\n",
    "        ip.extension_manager.load(\"autoreload\")  # type: ignore\n",
    "        %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from transformer_lens import HookedTransformer\n",
    "from transformer_lens.utils import get_device\n",
    "\n",
    "from sparse_autoencoder import (\n",
    "    ActivationResampler,\n",
    "    AdamWithReset,\n",
    "    L2ReconstructionLoss,\n",
    "    LearnedActivationsL1Loss,\n",
    "    LossReducer,\n",
    "    Pipeline,\n",
    "    PreTokenizedDataset,\n",
    "    SparseAutoencoder,\n",
    ")\n",
    "import wandb\n",
    "\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "device = get_device()\n",
    "print(f\"Using device: {device}\")  # You will need a GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The way this library works is that you can define your own hyper-parameters and then setup the\n",
    "underlying components with them. This is extremely flexible, but to help you get started we've\n",
    "included some common ones below along with some sensible defaults. You can also easily sweep through\n",
    "multiple hyperparameters with `wandb.sweep`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.random.manual_seed(49)\n",
    "\n",
    "hyperparameters = {\n",
    "    # Expansion factor is the number of features in the sparse representation, relative to the\n",
    "    # number of features in the original MLP layer. The original paper experimented with 1x to 256x,\n",
    "    # and we have found that 4x is a good starting point.\n",
    "    \"expansion_factor\": 4,\n",
    "    # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).\n",
    "    \"l1_coefficient\": 3e-4,\n",
    "    # Adam parameters (set to the default ones here)\n",
    "    \"lr\": 1e-4,\n",
    "    \"adam_beta_1\": 0.9,\n",
    "    \"adam_beta_2\": 0.999,\n",
    "    \"adam_epsilon\": 1e-8,\n",
    "    \"adam_weight_decay\": 0.0,\n",
    "    # Batch sizes\n",
    "    \"train_batch_size\": 4096,\n",
    "    \"context_size\": 128,\n",
    "    # Source model hook point\n",
    "    \"source_model_name\": \"gelu-2l\",\n",
    "    \"source_model_dtype\": \"float32\",\n",
    "    \"source_model_hook_point\": \"blocks.0.hook_mlp_out\",\n",
    "    \"source_model_hook_point_layer\": 0,\n",
    "    # Train pipeline parameters\n",
    "    \"max_store_size\": 384 * 4096 * 2,\n",
    "    \"max_activations\": 2_000_000_000,\n",
    "    \"resample_frequency\": 122_880_000,\n",
    "    \"checkpoint_frequency\": 100_000_000,\n",
    "    \"validation_frequency\": 384 * 4096 * 2 * 100,  # Every 100 generations\n",
    "}"
   ]
  },
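  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a feel for the scale of the pipeline values above, here is the arithmetic worked through.\n",
    "The symbols are notation introduced just for this note, and the step count assumes that each\n",
    "training step consumes exactly one batch of `train_batch_size` activations (an assumption, not\n",
    "something the pipeline guarantees):\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "N_{\\text{store}} &= 384 \\times 4096 \\times 2 = 3{,}145{,}728 \\\\\n",
    "N_{\\text{validate}} &= 100 \\times N_{\\text{store}} = 314{,}572{,}800 \\\\\n",
    "\\frac{N_{\\text{max}}}{B} &= \\frac{2{,}000{,}000{,}000}{4096} \\approx 488{,}281\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Here $N_{\\text{store}}$ is `max_store_size`, $N_{\\text{validate}}$ is `validation_frequency`,\n",
    "$N_{\\text{max}}$ is `max_activations` and $B$ is `train_batch_size`."
   ]
  },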
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Source Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The source model is just a [TransformerLens](https://github.com/neelnanda-io/TransformerLens) model\n",
    "(see [here](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n",
    "for a full list of supported models).\n",
    "\n",
    "In this example we're training a sparse autoencoder on the activations from the first MLP layer, so\n",
    "we'll also get some details about that hook point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source model setup with TransformerLens\n",
    "src_model = HookedTransformer.from_pretrained(\n",
    "    str(hyperparameters[\"source_model_name\"]), dtype=str(hyperparameters[\"source_model_dtype\"])\n",
    ")\n",
    "\n",
    "# Details about the activations we'll train the sparse autoencoder on\n",
    "autoencoder_input_dim: int = src_model.cfg.d_model  # type: ignore (TransformerLens typing is currently broken)\n",
    "\n",
    "f\"Source: {hyperparameters['source_model_name']}, \\\n",
    "    Hook: {hyperparameters['source_model_hook_point']}, \\\n",
    "    Features: {autoencoder_input_dim}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sparse Autoencoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then setup the sparse autoencoder. The default model (`SparseAutoencoder`) is setup as per\n",
    "the original Anthropic paper [Towards Monosemanticity: Decomposing Language Models With Dictionary\n",
    "Learning ](https://transformer-circuits.pub/2023/monosemantic-features/index.html).\n",
    "\n",
    "However it's just a standard PyTorch model, so you can create your own model instead if you want to\n",
    "use a different architecture. To do this you just need to extend the `AbstractAutoencoder`, and\n",
    "optionally the underlying `AbstractEncoder`, `AbstractDecoder` and `AbstractOuterBias`. See these\n",
    "classes (which are fully documented) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expansion_factor = hyperparameters[\"expansion_factor\"]\n",
    "autoencoder = SparseAutoencoder(\n",
    "    n_input_features=autoencoder_input_dim,  # size of the activations we are autoencoding\n",
    "    n_learned_features=int(autoencoder_input_dim * expansion_factor),  # size of SAE\n",
    ").to(device)\n",
    "autoencoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll also want to setup an Optimizer and Loss function. In this case we'll also use the standard\n",
    "approach from the original Anthropic paper. However you can create your own loss functions and\n",
    "optimizers by extending `AbstractLoss` and `AbstractOptimizerWithReset` respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n",
    "loss = LossReducer(\n",
    "    LearnedActivationsL1Loss(\n",
    "        l1_coefficient=float(hyperparameters[\"l1_coefficient\"]),\n",
    "    ),\n",
    "    L2ReconstructionLoss(),\n",
    ")\n",
    "loss"
   ]
  },
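  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what the reducer above adds up (an assumption based on the objective in the\n",
    "Anthropic paper, not a statement of this library's exact reductions or normalisation), the loss for\n",
    "a single input activation $x$, with learned activations $f(x)$ and reconstruction $\\hat{x}$, is:\n",
    "\n",
    "$$\n",
    "\\mathcal{L}(x) = \\underbrace{\\lVert x - \\hat{x} \\rVert_2^2}_{\\text{reconstruction}} + \\underbrace{\\lambda \\, \\lVert f(x) \\rVert_1}_{\\text{sparsity}}\n",
    "$$\n",
    "\n",
    "Here $\\lambda$ corresponds to `l1_coefficient`. The symbols $x$, $f(x)$ and $\\hat{x}$ are notation\n",
    "introduced for this note. A larger $\\lambda$ pushes more learned activations towards zero, at the\n",
    "cost of reconstruction accuracy."
   ]
  },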
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = AdamWithReset(\n",
    "    params=autoencoder.parameters(),\n",
    "    named_parameters=autoencoder.named_parameters(),\n",
    "    lr=float(hyperparameters[\"lr\"]),\n",
    "    betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n",
    "    eps=float(hyperparameters[\"adam_epsilon\"]),\n",
    "    weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n",
    "    has_components_dim=True,\n",
    ")\n",
    "optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we'll initialise an activation resampler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation_resampler = ActivationResampler(\n",
    "    resample_interval=10_000,\n",
    "    n_activations_activity_collate=10_000,\n",
    "    max_n_resamples=5,\n",
    "    n_learned_features=autoencoder.n_learned_features,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Source dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is just a dataset of tokenized prompts, to be used in generating activations (which are in turn\n",
    "used to train the SAE)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_data = PreTokenizedDataset(\n",
    "    dataset_path=\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\",\n",
    "    context_size=int(hyperparameters[\"context_size\"]),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you initialise [wandb](https://wandb.ai/site), the pipeline will automatically log all metrics to\n",
    "wandb. However, we should pass in a dictionary with all of our hyperaparameters so they're on \n",
    "wandb. \n",
    "\n",
    "We strongly encourage users to make use of wandb in order to understand the training process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = Path(\"../../.checkpoints\")\n",
    "checkpoint_path.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Path(\".cache/\").mkdir(exist_ok=True)\n",
    "wandb.init(\n",
    "    project=\"sparse-autoencoder\",\n",
    "    dir=\".cache\",\n",
    "    config=hyperparameters,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = Pipeline(\n",
    "    activation_resampler=activation_resampler,\n",
    "    autoencoder=autoencoder,\n",
    "    cache_names=[str(hyperparameters[\"source_model_hook_point\"])],\n",
    "    checkpoint_directory=checkpoint_path,\n",
    "    layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n",
    "    loss=loss,\n",
    "    optimizer=optimizer,\n",
    "    source_data_batch_size=6,\n",
    "    source_dataset=source_data,\n",
    "    source_model=src_model,\n",
    ")\n",
    "\n",
    "pipeline.run_pipeline(\n",
    "    train_batch_size=int(hyperparameters[\"train_batch_size\"]),\n",
    "    max_store_size=int(hyperparameters[\"max_store_size\"]),\n",
    "    max_activations=int(hyperparameters[\"max_activations\"]),\n",
    "    checkpoint_frequency=int(hyperparameters[\"checkpoint_frequency\"]),\n",
    "    validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "31186ba1239ad81afeb3c631b4833e71f34259d3b92eebb37a9091b916e08620"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
