{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import Config\n",
    "from e2e_sae.scripts.train_tlens_saes.run_train_tlens_saes import main as run_training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a Config\n",
    "The sample config below will train a single SAE on layer 6 (`blocks.6.hook_resid_pre`) of `gpt2-small` using an end-to-end (e2e) loss.\n",
    "\n",
    "Note that this will take 10-11 hours to run on an A100. See\n",
    "[e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml](../e2e_sae/scripts/train_tlens_saes/tinystories_1M_e2e.yaml)\n",
    "for a tinystories-1m config, or simply choose a smaller model to train on and adjust\n",
    "`n_ctx` and the dataset accordingly (other pre-tokenized datasets are listed [on Hugging Face](https://huggingface.co/ANONYMIZED)).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = Config(\n",
    "    wandb_project=\"gpt2-e2e_play\",\n",
    "    wandb_run_name=None,  # If not set, will use a name based on important config values\n",
    "    wandb_run_name_prefix=\"\",\n",
    "    seed=0,\n",
    "    tlens_model_name=\"gpt2-small\",\n",
    "    tlens_model_path=None,\n",
    "    n_samples=400_000,\n",
    "    save_every_n_samples=None,\n",
    "    eval_every_n_samples=40_000,\n",
    "    eval_n_samples=500,\n",
    "    log_every_n_grad_steps=20,\n",
    "    collect_act_frequency_every_n_samples=40_000,\n",
    "    act_frequency_n_tokens=500_000,\n",
    "    batch_size=8,\n",
    "    effective_batch_size=16,  # Number of samples before each optimizer step\n",
    "    lr=5e-4,\n",
    "    lr_schedule=\"cosine\",\n",
    "    min_lr_factor=0.1,  # Minimum learning rate as a fraction of the initial learning rate\n",
    "    warmup_samples=20_000,  # Linear warmup over this many samples\n",
    "    max_grad_norm=10.0,  # Gradient norms get clipped to this value before optimizer steps\n",
    "    loss={\n",
    "        # Note that \"original acts\" below refers to the activations in a model without SAEs\n",
    "        \"sparsity\": {\n",
    "            \"p_norm\": 1.0,  # p value in Lp norm\n",
    "            \"coeff\": 1.5,  # Multiplies the Lp norm in the loss (sparsity coefficient)\n",
    "        },\n",
    "        \"in_to_orig\": None,  # Used for e2e+future recon. MSE between the input to the SAE and original acts\n",
    "        \"out_to_orig\": None,  # Not commonly used. MSE between the output of the SAE and original acts\n",
    "        \"out_to_in\": {\n",
    "            # Multiplies the MSE between the output and input of the SAE. Setting to 0 lets us track this\n",
    "            # loss during training without optimizing it\n",
    "            \"coeff\": 0.0,\n",
    "        },\n",
    "        \"logits_kl\": {\n",
    "            \"coeff\": 1.0,  # Multiplies the KL divergence between the logits of the SAE model and original model\n",
    "        },\n",
    "    },\n",
    "    train_data={\n",
    "        # See https://huggingface.co/ANONYMIZED for other pre-tokenized datasets\n",
    "        \"dataset_name\": \"ANONYMIZED/Skylion007-openwebtext-tokenizer-gpt2\",\n",
    "        \"is_tokenized\": True,\n",
    "        \"tokenizer_name\": \"gpt2\",\n",
    "        \"streaming\": True,\n",
    "        \"split\": \"train\",\n",
    "        \"n_ctx\": 1024,\n",
    "    },\n",
    "    eval_data={\n",
    "        # By default the eval data uses a different seed from the training data; set `seed` here to override it\n",
    "        \"dataset_name\": \"ANONYMIZED/Skylion007-openwebtext-tokenizer-gpt2\",\n",
    "        \"is_tokenized\": True,\n",
    "        \"tokenizer_name\": \"gpt2\",\n",
    "        \"streaming\": True,\n",
    "        \"split\": \"train\",\n",
    "        \"n_ctx\": 1024,\n",
    "    },\n",
    "    saes={\n",
    "        \"retrain_saes\": False,  # Determines whether to continue training the SAEs in pretrained_sae_paths\n",
    "        \"pretrained_sae_paths\": None,  # Path or paths to pretrained SAEs\n",
    "        \"sae_positions\": [  # Position or positions to place SAEs in the model\n",
    "            \"blocks.6.hook_resid_pre\",\n",
    "        ],\n",
    "        \"dict_size_to_input_ratio\": 60.0,  # Size of the dictionary relative to the activations at the SAE positions\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run_training(config)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sp-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
