{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "45e8a0ab",
      "metadata": {},
      "source": [
        "# Small toy model experiments\n",
        "\n",
        "In this notebook, we explore the effect of L0 on SAEs when there are correlated features with a tiny toy model where everything is easily visible."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8ef804a4",
      "metadata": {},
      "source": [
        "# Positive correlations\n",
        "\n",
        "We'll start with a simple toy model with 5 features, and a correlation matrix where all features are positively correlated with feature 0. All other features are uncorrelated with each other."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8f1a13f9",
      "metadata": {},
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "import torch\n",
        "from tqdm import tqdm\n",
        "\n",
        "from sparse_but_wrong.toy_models.get_training_batch import  get_training_batch\n",
        "from sparse_but_wrong.toy_models.toy_model import ToyModel\n",
        "from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix\n",
        "from sparse_but_wrong.util import DEFAULT_DEVICE\n",
        "\n",
        "tqdm._instances.clear()  # type: ignore\n",
        "\n",
        "# Set up the topy model to have L0=2\n",
        "feat_probs = torch.ones(5) * 2 / 5\n",
        "pos_correlations = torch.zeros(5, 5)\n",
        "pos_correlations[:, 0] = 0.4\n",
        "pos_correlations[0, :] = 0.4\n",
        "pos_correlations.fill_diagonal_(1.0)\n",
        "\n",
        "toy_model = ToyModel(num_feats=5, hidden_dim=20).to(DEFAULT_DEVICE)\n",
        "\n",
        "\n",
        "generate_batch_pos = partial(\n",
        "    get_training_batch,\n",
        "    firing_probabilities=feat_probs,\n",
        "    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,\n",
        "    correlation_matrix=pos_correlations,\n",
        ")\n",
        "\n",
        "\n",
        "plot_correlation_matrix(pos_correlations, save_path=\"plots/toy_setup_small/positive_correlation_matrix.pdf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8aef584d",
      "metadata": {},
      "source": [
        "## Finding the True L0\n",
        "\n",
        "Next, we calculate the true L0 for this dataset (spoiler: it's ~2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2d2f1ea7",
      "metadata": {},
      "outputs": [],
      "source": [
        "sample = generate_batch_pos(100_000)\n",
        "true_l0 = (sample > 0).float().sum(dim=-1).mean()\n",
        "print(f\"True L0: {true_l0}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "04a8470a",
      "metadata": {},
      "source": [
        "## Training an SAE with the right L0\n",
        "\n",
        "Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8e8bd749",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=2, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_full = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# NOTE: occasionaly this gets stuck in poor local minima. If this happens, try rerunning and it should converge properly.\n",
        "train_toy_sae(sae_full, toy_model, generate_batch_pos)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "af294427",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_full, toy_model, \"SAE L0 = True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/pos_corr_sae_l0_eq_true_l0_decoder_cos_sims.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/pos_corr_sae_l0_eq_true_l0_decoder_cos_sims.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3dc66a9c",
      "metadata": {},
      "source": [
        "As we can see above, the SAE has learned the correct features despite the feature correlations."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "dde3bc76",
      "metadata": {},
      "source": [
        "## What if reduce the L0 slightly below the number of true features?\n",
        "\n",
        "If we lower the L0, then then SAE will need to reconstruct the input with less latents than the number of true features required. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is nice for our experiments). We'll further initialize the SAE to the ground truth correct solution, so we can be certain that anything that results is a result of gradient pressure rather than just a poor local minimum."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "337325f4",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_narrow = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "init_sae_to_match_model(sae_narrow, toy_model)\n",
        "\n",
        "train_toy_sae(sae_narrow, toy_model, generate_batch_pos)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cece73e4",
      "metadata": {},
      "outputs": [],
      "source": [
        "import plotly.express as px\n",
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_narrow, toy_model, \"SAE L0 < True L0\", reorder_features=True)\n",
        "px.imshow(pos_correlations, width=400, height=400, color_continuous_scale=\"RdBu\", zmin=-1, zmax=1, title=\"Feature correlation matrix\", origin=\"lower\").show()\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title=\"SAE L0 $<$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/pos_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf\")\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "00ea495b",
      "metadata": {},
      "source": [
        "We see that all SAE latents contain a strong positive component of feature 0, but no component of any other feautre, exactly matching the correlation matrix. Clearly the correlations are what is causing the SAE to mix feature components together."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a0b7a631",
      "metadata": {},
      "source": [
        "# Negative correlations\n",
        "\n",
        "Next, we'll change the correlation matrix so all features are negatively correlated with feature 0 instead of positively correlated. We'll keep the same toy model as before, we're just changing feature firing correlations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8670802c",
      "metadata": {},
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "import torch\n",
        "\n",
        "from sparse_but_wrong.toy_models.get_training_batch import  get_training_batch\n",
        "from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix\n",
        "\n",
        "# Set up the topy model to have L0=2\n",
        "feat_probs = torch.ones(5) * 2 / 5\n",
        "neg_correlations = torch.zeros(5, 5)\n",
        "neg_correlations[:, 0] = -0.4\n",
        "neg_correlations[0, :] = -0.4\n",
        "neg_correlations.fill_diagonal_(1.0)\n",
        "\n",
        "\n",
        "generate_batch_neg = partial(\n",
        "    get_training_batch,\n",
        "    firing_probabilities=feat_probs,\n",
        "    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,\n",
        "    correlation_matrix=neg_correlations,\n",
        ")\n",
        "\n",
        "plot_correlation_matrix(neg_correlations, save_path=\"plots/toy_setup_small/negative_correlation_matrix.pdf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "167b875f",
      "metadata": {},
      "source": [
        "## Finding the True L0\n",
        "\n",
        "Let's double-check that the true L0 is still 2, as before."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "048132c2",
      "metadata": {},
      "outputs": [],
      "source": [
        "sample = generate_batch_neg(100_000)\n",
        "true_l0 = (sample > 0).float().sum(dim=-1).mean()\n",
        "print(f\"True L0: {true_l0}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "88d79850",
      "metadata": {},
      "source": [
        "## Training an SAE with the right L0\n",
        "\n",
        "Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a866615e",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=2, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_full = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# NOTE: occasionaly this gets stuck in poor local minima. If this happens, try rerunning and it should converge properly.\n",
        "train_toy_sae(sae_full, toy_model, generate_batch_neg)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9393eee1",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_full, toy_model, \"SAE L0 = True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/neg_corr_sae_l0_eq_true_l0_decoder_cos_sims.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/neg_corr_sae_l0_eq_true_l0_decoder_cos_sims.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9562a826",
      "metadata": {},
      "source": [
        "As we can see above, the SAE has learned the correct features despite the feature correlations."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2fc547f4",
      "metadata": {},
      "source": [
        "## What if reduce the L0 slightly below the number of true features?\n",
        "\n",
        "If we lower the L0, then then SAE will need to reconstruct the input with less latents than the number of true features required. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is nice for our experiments). We'll further initialize the SAE to the ground truth correct solution, so we can be certain that anything that results is a result of gradient pressure rather than just a poor local minimum."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bde8f203",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_narrow = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "init_sae_to_match_model(sae_narrow, toy_model)\n",
        "\n",
        "train_toy_sae(sae_narrow, toy_model, generate_batch_neg)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "479345a3",
      "metadata": {},
      "outputs": [],
      "source": [
        "import plotly.express as px\n",
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_narrow, toy_model, \"SAE L0 < True L0\", reorder_features=True)\n",
        "px.imshow(neg_correlations, width=400, height=400, color_continuous_scale=\"RdBu\", zmin=-1, zmax=1, title=\"Feature correlation matrix\", origin=\"lower\").show()\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title=\"SAE L0 $<$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/neg_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
