{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/demo.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quick Start Training Demo\n",
    "\n",
    "This is a quick-start demo to get you training a sparse autoencoder (SAE) right away. All you need\n",
    "to do is choose a few hyperparameters (such as the model to train on) and then set it off. By\n",
    "default it trains SAEs on all MLP layers of GPT-2 small."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detect whether this notebook is running on Google Colab\n",
    "try:\n",
    "    import google.colab  # noqa: F401 # type: ignore\n",
    "except ImportError:\n",
    "    in_colab = False\n",
    "else:\n",
    "    in_colab = True\n",
    "\n",
    "if in_colab:\n",
    "    # On Colab, install the packages this demo uses\n",
    "    %pip install sparse_autoencoder transformer_lens transformers wandb\n",
    "else:\n",
    "    # Outside Colab (dev mode), enable hot reloading of imported modules\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from sparse_autoencoder import (\n",
    "    ActivationResamplerHyperparameters,\n",
    "    Hyperparameters,\n",
    "    LossHyperparameters,\n",
    "    Method,\n",
    "    OptimizerHyperparameters,\n",
    "    Parameter,\n",
    "    PipelineHyperparameters,\n",
    "    SourceDataHyperparameters,\n",
    "    SourceModelHyperparameters,\n",
    "    sweep,\n",
    "    SweepConfig,\n",
    ")\n",
    "\n",
    "\n",
    "# Configure third-party libraries through environment variables: turn tokenizers\n",
    "# parallelism off and give wandb this notebook's filename\n",
    "os.environ.update(\n",
    "    {\n",
    "        \"TOKENIZERS_PARALLELISM\": \"false\",\n",
    "        \"WANDB_NOTEBOOK_NAME\": \"demo.ipynb\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Customize any of the hyperparameters below (by default we sweep over the L1 coefficient and the\n",
    "learning rate):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SweepConfig(parameters=Hyperparameters(\n",
       "    source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)\n",
       "    source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))\n",
       "    activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=200000), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n",
       "    autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n",
       "    loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n",
       "    optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
       "    pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=1024), max_store_size=Parameter(value=300000), max_activations=Parameter(value=1000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=100000000), validation_n_activations=Parameter(value=8192))\n",
       "    random_seed=Parameter(value=49)\n",
       "), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_layers_gpt2_small = 12\n",
    "\n",
    "# Fixed settings are passed as `Parameter(value)`; the two swept settings (l1 coefficient and\n",
    "# learning rate) are given as `Parameter(max=..., min=...)` ranges.\n",
    "sweep_config = SweepConfig(\n",
    "    method=Method.RANDOM,\n",
    "    parameters=Hyperparameters(\n",
    "        source_data=SourceDataHyperparameters(\n",
    "            dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
    "            context_size=Parameter(128),\n",
    "            pre_tokenized=Parameter(True),\n",
    "        ),\n",
    "        source_model=SourceModelHyperparameters(\n",
    "            name=Parameter(\"gpt2-small\"),\n",
    "            # One hook name per MLP layer, so all layers are trained on in parallel\n",
    "            cache_names=Parameter(\n",
    "                [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers_gpt2_small)]\n",
    "            ),\n",
    "            hook_dimension=Parameter(768),\n",
    "        ),\n",
    "        activation_resampler=ActivationResamplerHyperparameters(\n",
    "            resample_interval=Parameter(200_000_000),\n",
    "            max_n_resamples=Parameter(4),\n",
    "            n_activations_activity_collate=Parameter(100_000_000),\n",
    "            resample_dataset_size=Parameter(200_000),\n",
    "            threshold_is_dead_portion_fires=Parameter(1e-6),\n",
    "        ),\n",
    "        # Swept: l1 coefficient\n",
    "        loss=LossHyperparameters(\n",
    "            l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
    "        ),\n",
    "        # Swept: learning rate\n",
    "        optimizer=OptimizerHyperparameters(\n",
    "            lr=Parameter(max=1e-3, min=1e-5),\n",
    "        ),\n",
    "        pipeline=PipelineHyperparameters(\n",
    "            train_batch_size=Parameter(1024),\n",
    "            max_store_size=Parameter(300_000),\n",
    "            max_activations=Parameter(1_000_000_000),\n",
    "            checkpoint_frequency=Parameter(100_000_000),\n",
    "            validation_frequency=Parameter(100_000_000),\n",
    "        ),\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Show the full configuration as the cell output\n",
    "sweep_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the sweep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the sweep using the `sweep_config` built in the hyperparameters cell\n",
    "sweep(sweep_config=sweep_config)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "31186ba1239ad81afeb3c631b4833e71f34259d3b92eebb37a9091b916e08620"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
