{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "45e8a0ab",
      "metadata": {},
      "source": [
        "# Main toy model experiments\n",
        "\n",
        "In this notebook, we explore the effect of L0 on SAEs when there are correlated features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8f1a13f9",
      "metadata": {},
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "from pathlib import Path\n",
        "import torch\n",
        "import pandas as pd\n",
        "from tqdm import tqdm\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from sparse_but_wrong.toy_models.get_training_batch import generate_random_correlation_matrix, get_training_batch\n",
        "from sparse_but_wrong.toy_models.toy_model import ToyModel\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "from sparse_but_wrong.util import DEFAULT_DEVICE\n",
        "\n",
        "tqdm._instances.clear()  # type: ignore\n",
        "\n",
        "NUM_FEATS = 50  # number of ground-truth features in the toy model\n",
        "SEED = 42\n",
        "\n",
        "# Seed torch's global RNG so that Restart & Run All is repeatable.\n",
        "# NOTE(review): assumes ToyModel init and get_training_batch draw from torch's global RNG - confirm.\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# Firing probabilities fall linearly from ~0.388 (first feature) to 0.05 (last feature).\n",
        "feat_probs = 0.345 * (NUM_FEATS - torch.arange(NUM_FEATS) - 1) / NUM_FEATS + 0.05\n",
        "\n",
        "# The correlation matrix is cached on disk so that every run reuses the same one.\n",
        "if Path(\"correlations.pt\").exists():\n",
        "    print(\"Loading correlations from disk\")\n",
        "    correlations = torch.load(\"correlations.pt\")\n",
        "else:\n",
        "    correlations = generate_random_correlation_matrix(\n",
        "        correlation_strength_range=(0.3, 0.9),\n",
        "        num_features=NUM_FEATS,\n",
        "        seed=SEED,\n",
        "    )\n",
        "    torch.save(correlations, \"correlations.pt\")\n",
        "\n",
        "# NOTE: the names in this frame are 1-indexed, but the bar chart below relabels its ticks with 0-indexed positions.\n",
        "indices = torch.arange(NUM_FEATS) + 1\n",
        "df = pd.DataFrame({\n",
        "    \"P_i\": feat_probs,\n",
        "    \"feature\": [str(i) for i in indices.tolist()],\n",
        "})\n",
        "\n",
        "\n",
        "if Path(\"toy_model.pt\").exists():\n",
        "    print(\"Loading toy model from disk\")\n",
        "    # weights_only=False unpickles arbitrary objects, so only load a file this notebook wrote itself.\n",
        "    # map_location puts a cached model on the same device as a freshly created one.\n",
        "    toy_model = torch.load(\"toy_model.pt\", weights_only=False, map_location=DEFAULT_DEVICE)\n",
        "else:\n",
        "    toy_model = ToyModel(num_feats=NUM_FEATS, hidden_dim=100).to(DEFAULT_DEVICE)\n",
        "    torch.save(toy_model, \"toy_model.pt\")\n",
        "\n",
        "\n",
        "# Batch generator used by every experiment below; only the batch size is left to the caller.\n",
        "generate_batch = partial(\n",
        "    get_training_batch,\n",
        "    firing_probabilities=feat_probs,\n",
        "    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,\n",
        "    correlation_matrix=correlations,\n",
        ")\n",
        "\n",
        "Path(\"plots/toy_setup\").mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# Set once: rc_context restores the surrounding rcParams on exit, so the value persists for both figures.\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(3, 2))\n",
        "    sns.barplot(data=df, x=\"feature\", y=\"P_i\")\n",
        "    plt.xlabel(\"Feature\")\n",
        "    plt.ylabel(\"$P_i$\")\n",
        "    plt.title(\"Feature firing probabilities $P_i$\")\n",
        "    # Increase tick spacing to prevent overlapping\n",
        "    plt.xticks(range(0, len(df), 5), [str(i) for i in range(0, len(df), 5)])  # Show every 5th tick, 0-indexed\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(\"plots/toy_setup/toy_model_feature_firing_probabilities.pdf\")\n",
        "    plt.show()\n",
        "\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(2.5, 2))\n",
        "    sns.heatmap(correlations, cmap=\"RdBu\", center=0, vmin=-1, vmax=1)\n",
        "    plt.xlabel(\"Feature\")\n",
        "    plt.ylabel(\"Feature\")\n",
        "    plt.title(\"Feature correlation matrix\")\n",
        "    # Increase tick spacing to prevent overlapping\n",
        "    plt.xticks(range(0, len(correlations), 10), [str(i) for i in range(0, len(correlations), 10)])  # Show every 10th tick, 0-indexed\n",
        "    plt.yticks(range(0, len(correlations), 10), [str(i) for i in range(0, len(correlations), 10)])  # Show every 10th tick, 0-indexed\n",
        "    # tight_layout() is not applied to this figure.\n",
        "    plt.savefig(\"plots/toy_setup/toy_model_correlation_matrix.pdf\")\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8aef584d",
      "metadata": {},
      "source": [
        "## Finding the True L0\n",
        "\n",
        "Next, we calculate the true L0 for this dataset (spoiler: it's ~11)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2d2f1ea7",
      "metadata": {},
      "outputs": [],
      "source": [
        "sample = generate_batch(100_000)\n",
        "true_l0 = (sample > 0).float().sum(dim=-1).mean()\n",
        "print(f\"True L0: {true_l0}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "04a8470a",
      "metadata": {},
      "source": [
        "## Training an SAE with the right L0\n",
        "\n",
        "Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8e8bd749",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=11, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "sae_full = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# NOTE: occasionally this gets stuck in poor local minima. If this happens, try rerunning and it should converge properly.\n",
        "train_toy_sae(sae_full, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "af294427",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "\n",
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "# Create the figure directory up front, as is done for the other plot folders in this notebook.\n",
        "# NOTE(review): harmless if plot_sae_feat_cos_sims_seaborn already creates it - confirm.\n",
        "Path(\"plots/toy_l0\").mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_full, toy_model, \"SAE L0 = True L0\", reorder_features=True, dtick=5)\n",
        "\n",
        "# Save the same decoder figure once as PDF and once as PNG.\n",
        "for ext in (\"pdf\", \"png\"):\n",
        "    plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=f\"plots/toy_l0/sae_l0_eq_true_l0_decoder_cos_sims.{ext}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3dc66a9c",
      "metadata": {},
      "source": [
        "As we can see above, the SAE has learned the correct features."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "dde3bc76",
      "metadata": {},
      "source": [
        "## What if we reduce the L0 below the number of true features?\n",
        "\n",
        "If we lower the L0, then the SAE will need to reconstruct the input with fewer latents than the number of true features required. How will it handle this? We set L0=9 below (the true L0 is ~11)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "337325f4",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=9, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "sae_narrow = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "train_toy_sae(sae_narrow, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cece73e4",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_narrow, toy_model, \"SAE L0 < True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title=\"SAE L0 $<$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_lt_true_l0_decoder_cos_sims.pdf\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "aeaeda3a",
      "metadata": {},
      "source": [
        "The SAE is now broken - we no longer have a clear latent tracking feature 0, even though this is the most important feature! Instead every latent gets a component of feature 0 merged into it. The latents tracking the feature with the highest correlation with feature 0 are the most messed up by this. Essentially, the SAE is trying to \"cheat\" and reconstruct feature 0 without having to dedicate a latent to it.\n",
        "\n",
        "This breaks all the latents of the SAE by hedging a component of feature 0 into them - but this broken behavior will result in a better MSE score than if the SAE learned all latents correctly, as we'll see next."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d5c0eb2b",
      "metadata": {},
      "source": [
        "### MSE comparison: correct SAE vs low L0 SAE\n",
        "\n",
        "Why does the SAE learn these broken latents instead of learning the correct representations of features and just selecting the top 9 instead of the top 11? Let's compare the MSE loss of the correct SAE (trained with L0=11) when it is run at L0=9 against the MSE loss of the broken SAE that we trained explicitly with L0=9.\n",
        "\n",
        "We'll create 100k training samples, run them through our correct SAE modified to use L0=9, and compare the MSE of this SAE with that of the broken SAE trained explicitly with L0=9."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a0ec2762",
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Evaluation only, so skip autograd bookkeeping for the 100k-sample batch.\n",
        "with torch.no_grad():\n",
        "    sample_feats = toy_model(generate_batch(100_000))\n",
        "\n",
        "    # Temporarily run the correct SAE (trained with k=11) at k=9 so both SAEs are compared at the same L0.\n",
        "    sae_full.activation_fn.k = 9  # type: ignore[attr-defined]\n",
        "    try:\n",
        "        correct_sae_mse = (sae_full(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()\n",
        "    finally:\n",
        "        # Restore k=11 even if the forward pass fails, so later cells still see the SAE as it was trained.\n",
        "        sae_full.activation_fn.k = 11  # type: ignore[attr-defined]\n",
        "\n",
        "    narrow_sae_mse = (sae_narrow(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()\n",
        "\n",
        "print(f\"Correct SAE MSE: {correct_sae_mse:.2f}\")\n",
        "print(f\"Narrow SAE MSE: {narrow_sae_mse:.2f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4826fb06",
      "metadata": {},
      "source": [
        "The correct SAE run at L0=9 results in a higher MSE loss than the broken SAE trained at k=9 (compare the two values printed above)!"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f8551162",
      "metadata": {},
      "source": [
        "## Let's drop the L0 even further!\n",
        "\n",
        "What if we drop the L0 further, so k=5?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8927e492",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=5, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "sae_narrower = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "train_toy_sae(sae_narrower, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d794291e",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Import both plotting helpers here so this cell does not depend on an import made by an earlier cell.\n",
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_narrower, toy_model, \"SAE L0 << True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrower, toy_model, title=\"SAE L0 $\\\\ll$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_llt_true_l0_decoder_cos_sims.pdf\")\n",
        "# The _v2 file is the same figure with a \"<\" title instead of \"<<\".\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrower, toy_model, title=\"SAE L0 $<$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_llt_true_l0_decoder_cos_sims_v2.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrower, toy_model, title=\"SAE L0 $\\\\ll$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_llt_true_l0_decoder_cos_sims.png\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b13e89a1",
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Evaluation only, so skip autograd bookkeeping for the 100k-sample batch.\n",
        "with torch.no_grad():\n",
        "    sample_feats = toy_model(generate_batch(100_000))\n",
        "\n",
        "    # Temporarily run the correct SAE (trained with k=11) at k=5 so both SAEs are compared at the same L0.\n",
        "    sae_full.activation_fn.k = 5  # type: ignore[attr-defined]\n",
        "    try:\n",
        "        correct_sae_mse = (sae_full(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()\n",
        "    finally:\n",
        "        # Restore k=11 even if the forward pass fails, so later cells still see the SAE as it was trained.\n",
        "        sae_full.activation_fn.k = 11  # type: ignore[attr-defined]\n",
        "\n",
        "    narrower_sae_mse = (sae_narrower(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()\n",
        "\n",
        "print(f\"Correct SAE MSE: {correct_sae_mse:.2f}\")\n",
        "print(f\"Narrower SAE MSE: {narrower_sae_mse:.2f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c1035f4d",
      "metadata": {},
      "source": [
        "As expected, the SAE is even worse now! The lower-frequency features are still reconstructed mostly correctly aside from hedging a component of feature 0, but now even more of the earlier latents are messed up as the SAE finds degenerate ways to improve its reconstruction error despite not having enough L0 to do so correctly. It's hard to see from the plot, but the magnitude of the hedging of feature 0 is also higher now."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bc816d9d",
      "metadata": {},
      "source": [
        "## What if we initialize the SAE to the ground truth SAE?\n",
        "\n",
        "Next, let's verify that what we're seeing is a result of gradient pressure rather than just being a poor local minimum. We'll initialize the SAE to the ground truth SAE and see if further training results in polysemantic latents."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c7ad1d58",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=5, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "sae_narrower_gt = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# Initialize this to the ground truth correct solution\n",
        "init_sae_to_match_model(sae_narrower_gt, toy_model)\n",
        "\n",
        "train_toy_sae(sae_narrower_gt, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "09ce2807",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Import both plotting helpers here so this cell does not depend on an import made by an earlier cell.\n",
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_narrower_gt, toy_model, \"Finetuned SAE, L0 << True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrower_gt, toy_model, title=\"Finetuned SAE, L0 $\\\\ll$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_gt_l0_llt_true_l0_decoder_cos_sims.pdf\")\n",
        "# The _v2 file is the same figure with a \"<\" title instead of \"<<\".\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrower_gt, toy_model, title=\"Finetuned SAE, L0 $<$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_gt_l0_llt_true_l0_decoder_cos_sims_v2.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrower_gt, toy_model, title=\"Finetuned SAE, L0 $\\\\ll$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_gt_l0_llt_true_l0_decoder_cos_sims.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d7cb3d70",
      "metadata": {},
      "source": [
        "## What if we increase the L0 beyond the number of true features?\n",
        "\n",
        "If we increase the L0, then the SAE can activate more latents than are necessary to reconstruct the input. Will the SAE still do the right thing? We set L0=14 below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5b31f363",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=14, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "sae_wide = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "train_toy_sae(sae_wide, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "845099b4",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_wide, toy_model, \"SAE L0 > True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_wide, toy_model, title=\"SAE L0 $>$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_gt_true_l0_decoder_cos_sims.pdf\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1fd6ec34",
      "metadata": {},
      "source": [
        "This is still mostly correct - we don't see any hedging at least. But the SAE is finding more degenerate local minima than it did before. Let's see what happens if we increase the L0 even more. "
      ]
    },
    {
      "cell_type": "markdown",
      "id": "837ac8f6",
      "metadata": {},
      "source": [
        "## Increase L0 even more!\n",
        "\n",
        "Let's see what happens if we increase the L0 to 18"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5407ceca",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=18, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "sae_wider = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "\n",
        "train_toy_sae(sae_wider, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "224e63ae",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_wider, toy_model, \"SAE L0 >> True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_wider, toy_model, title=\"SAE L0 $\\\\gg$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_ggt_true_l0_decoder_cos_sims.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_wider, toy_model, title=\"SAE L0 $>$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_ggt_true_l0_decoder_cos_sims_v2.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_wider, toy_model, title=\"SAE L0 $\\\\gg$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_l0_ggt_true_l0_decoder_cos_sims.png\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "da607817",
      "metadata": {},
      "source": [
        "As expected, the SAE is even worse now, using its extra capacity to allow some more degenerate solutions. There is still notably no hedging though, but the SAE is clearly not doing what we want."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f14d046f",
      "metadata": {},
      "source": [
        "# Can we detect when we're at the correct L0?\n",
        "\n",
        "Let's train SAEs at a range of L0, ranging from below the correct L0 to above the correct L0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ca71be7a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import defaultdict\n",
        "from pathlib import Path\n",
        "import torch\n",
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "# Maps k -> list of SAEs trained at that k, one per seed (in seed order).\n",
        "saes_by_k = defaultdict(list)\n",
        "for seed in range(5):\n",
        "    seed_dir = Path(f\"saes_by_k/seed_{seed}\")\n",
        "    seed_dir.mkdir(parents=True, exist_ok=True)\n",
        "    for k in range(1, 26):\n",
        "        sae_path = str(seed_dir / str(k))\n",
        "        if Path(sae_path).exists():\n",
        "            # Trained SAEs are cached on disk so that re-running the notebook is cheap.\n",
        "            print(f\"Loading SAE with k={k}, seed={seed} from disk\")\n",
        "            sae = BatchTopKTrainingSAE.load_from_disk(sae_path)\n",
        "        else:\n",
        "            print(f\"Training SAE with k={k}, seed={seed}\")\n",
        "            # Actually seed the RNG: without this, `seed` only labels the output directory.\n",
        "            # NOTE(review): assumes SAE init and batch sampling use torch's global RNG - confirm.\n",
        "            torch.manual_seed(seed)\n",
        "            cfg = BatchTopKTrainingSAEConfig(k=k, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "            sae = BatchTopKTrainingSAE(cfg)\n",
        "            train_toy_sae(sae, toy_model, generate_batch)\n",
        "            sae.save_model(sae_path)\n",
        "        saes_by_k[k].append(sae)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d3255bba",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "from typing import Callable\n",
        "import torch\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from tqdm import tqdm\n",
        "from sae_lens import TrainingSAE\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "\n",
        "# toy_model, generate_batch and saes_by_k are defined by earlier cells of this notebook.\n",
        "\n",
        "TRUE_L0 = 11  # position of the vertical marker line in the plots below\n",
        "\n",
        "sorted_dec_hidden_pres = {}  # not used in this cell\n",
        "\n",
        "\n",
        "@torch.no_grad()  # evaluation only: no autograd graph is needed for the projections\n",
        "def calc_l0_dec_thresholds(sae: TrainingSAE, model, generate_batch: Callable[[int], torch.Tensor]) -> list[float]:\n",
        "    \"\"\"Return decoder-projection cutoffs, one per latent count.\n",
        "\n",
        "    Embeds a fresh batch of 100,000 samples, projects it (minus the decoder bias)\n",
        "    onto every decoder direction and sorts all projections in descending order.\n",
        "    Entry n of the result is the value at sorted position n * batch_size, i.e. the\n",
        "    cutoff that has n * batch_size projections ranked above it (n per sample on average).\n",
        "    \"\"\"\n",
        "    inputs = model.embed(generate_batch(100_000))\n",
        "    hidden_pre_dec = (inputs - sae.b_dec) @ sae.W_dec.T\n",
        "    sorted_hidden_pre_dec = hidden_pre_dec.flatten().sort(descending=True).values\n",
        "    # Build the indices on the same device as the tensor they index into.\n",
        "    k_inds = torch.arange(hidden_pre_dec.shape[-1], device=sorted_hidden_pre_dec.device) * hidden_pre_dec.shape[0]\n",
        "    return sorted_hidden_pre_dec[k_inds].tolist()\n",
        "\n",
        "\n",
        "# thresholds[k] holds one list of cutoffs per seed for the SAEs trained at that k.\n",
        "thresholds = {}\n",
        "for k, saes in tqdm(saes_by_k.items()):\n",
        "    thresholds[k] = [calc_l0_dec_thresholds(sae, toy_model, generate_batch) for sae in saes]\n",
        "\n",
        "# One row per (l0, seed) with a column l0_dec_<n> for every cutoff.\n",
        "data = []\n",
        "for l0, l0_thresholds in thresholds.items():\n",
        "    for seed, seed_thresholds in enumerate(l0_thresholds):\n",
        "        row = {\"l0\": l0, \"seed\": seed}\n",
        "        for i, threshold in enumerate(seed_thresholds):\n",
        "            row[f\"l0_dec_{i}\"] = threshold\n",
        "        data.append(row)\n",
        "df = pd.DataFrame(data)\n",
        "\n",
        "Path(\"plots/dpn\").mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# Global plot settings only need to be applied once, not on every loop iteration.\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "sns.set_theme()\n",
        "\n",
        "for n in range(12, 26):\n",
        "    with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "        plt.figure(figsize=(3, 1.5))\n",
        "\n",
        "        # Mean and standard deviation across seeds at each SAE L0\n",
        "        grouped = df.groupby(\"l0\")[f\"l0_dec_{n}\"]\n",
        "        mean_values = grouped.mean()\n",
        "        std_values = grouped.std()\n",
        "        l0_values = mean_values.index\n",
        "\n",
        "        # Mean line with a +/- 1 standard deviation band\n",
        "        plt.plot(l0_values, mean_values, linewidth=1.0)\n",
        "        plt.fill_between(l0_values, mean_values - std_values, mean_values + std_values, alpha=0.3)\n",
        "\n",
        "        # Vertical line at the true L0\n",
        "        plt.axvline(x=TRUE_L0, color='red', linestyle='--', linewidth=0.5, alpha=0.7, label='True L0')\n",
        "\n",
        "        plt.title(f\"$n={n}$ Decoder Projection vs SAE L0\")\n",
        "        plt.xlabel(\"SAE L0\")\n",
        "        plt.ylabel(\"$s_n^{dec}$\")\n",
        "        plt.grid(True, alpha=0.3)\n",
        "\n",
        "        plt.tight_layout()\n",
        "        plt.savefig(f\"plots/dpn/dp_{n}.pdf\")\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "31696747",
      "metadata": {},
      "source": [
        "# Sparsity vs Reconstruction Tradeoff\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d50a77cf",
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import torch\n",
        "\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "TRUE_L0 = 11  # k of the ground-truth SAE and position of the marker line in the plot\n",
        "\n",
        "# create a ground-truth SAE from the toy model\n",
        "cfg = BatchTopKTrainingSAEConfig(k=TRUE_L0, d_in=toy_model.embed.weight.shape[0], d_sae=50)\n",
        "ground_truth_sae = BatchTopKTrainingSAE(cfg)\n",
        "init_sae_to_match_model(ground_truth_sae, toy_model)\n",
        "\n",
        "# Evaluation only, so skip autograd bookkeeping for the 100k-sample batch.\n",
        "with torch.no_grad():\n",
        "    test_inputs = toy_model(generate_batch(100_000))\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def calculate_variance_explained(sae, test_inputs):\n",
        "    \"\"\"Return 1 - (mean squared reconstruction error) / (mean squared input norm).\n",
        "\n",
        "    Both means are taken over samples after summing squares over the last dimension.\n",
        "    Note that the inputs are not mean-centered before computing the denominator.\n",
        "    \"\"\"\n",
        "    sae_outputs = sae(test_inputs)\n",
        "    residual_sq = (sae_outputs - test_inputs).pow(2).sum(dim=-1).mean().item()\n",
        "    total_sq = test_inputs.pow(2).sum(dim=-1).mean().item()\n",
        "    return 1 - residual_sq / total_sq\n",
        "\n",
        "\n",
        "data = []\n",
        "for k, saes in saes_by_k.items():\n",
        "    # Calculate variance explained for each SAE (seed) at this k\n",
        "    for seed, sae in enumerate(saes):\n",
        "        data.append({\n",
        "            \"k\": k,\n",
        "            \"seed\": seed,\n",
        "            \"variance_explained\": calculate_variance_explained(sae, test_inputs),\n",
        "            \"variant\": \"Trained SAE\"\n",
        "        })\n",
        "\n",
        "    # Ground truth SAE (same for all seeds), evaluated at this k\n",
        "    ground_truth_sae.activation_fn.k = k # type: ignore[attr-defined]\n",
        "    data.append({\n",
        "        \"k\": k,\n",
        "        \"seed\": 0,  # Ground truth doesn't vary by seed\n",
        "        \"variance_explained\": calculate_variance_explained(ground_truth_sae, test_inputs),\n",
        "        \"variant\": \"Ground-truth SAE\"\n",
        "    })\n",
        "\n",
        "# Put the ground-truth SAE back to its configured k so later cells do not see the last sweep value.\n",
        "ground_truth_sae.activation_fn.k = TRUE_L0 # type: ignore[attr-defined]\n",
        "\n",
        "df = pd.DataFrame(data).sort_values(by=\"k\")\n",
        "\n",
        "sns.set_theme()\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(3, 2))\n",
        "\n",
        "    # Plot Trained SAE with mean and std across seeds\n",
        "    trained_df = df[df[\"variant\"] == \"Trained SAE\"]\n",
        "    grouped_trained = trained_df.groupby(\"k\")[\"variance_explained\"]\n",
        "    mean_trained = grouped_trained.mean()\n",
        "    std_trained = grouped_trained.std()\n",
        "    k_values = mean_trained.index\n",
        "\n",
        "    plt.plot(k_values, mean_trained, linewidth=1, label=\"Trained SAE\")\n",
        "    plt.fill_between(k_values, mean_trained - std_trained, mean_trained + std_trained, alpha=0.3)\n",
        "\n",
        "    # Plot Ground-truth SAE (no std band: there is a single value per k)\n",
        "    gt_df = df[df[\"variant\"] == \"Ground-truth SAE\"]\n",
        "    gt_grouped = gt_df.groupby(\"k\")[\"variance_explained\"].mean()\n",
        "    plt.plot(gt_grouped.index.to_numpy(), gt_grouped.to_numpy(), linewidth=1, label=\"Ground-truth SAE\")\n",
        "\n",
        "    # Add vertical line at true L0\n",
        "    plt.axvline(x=TRUE_L0, color='red', linestyle='--', linewidth=0.5, alpha=0.7)\n",
        "\n",
        "    plt.xlabel(\"L0\")\n",
        "    plt.ylabel(\"Variance explained\")\n",
        "    plt.title(\"Sparsity vs reconstruction tradeoff\")\n",
        "    plt.legend()\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(\"plots/sparsity_vs_reconstruction_tradeoff.pdf\")\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8024a7f0",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(ground_truth_sae, toy_model, \"Ground truth SAE\", reorder_features=True, dtick=5)\n",
        "# Saved under plots/toy_l0/ like every other decoder cos-sim figure in this notebook.\n",
        "plot_sae_feat_cos_sims_seaborn(ground_truth_sae, toy_model, title=\"Ground truth SAE\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=\"plots/toy_l0/sae_ground_truth_decoder_cos_sims.pdf\")\n",
        "\n",
        "# Decoder plots for the first-seed SAE at a few selected L0 values from the sweep.\n",
        "for l0 in (1, 2, 11):\n",
        "    plot_sae_feat_cos_sims_seaborn(saes_by_k[l0][0], toy_model, title=f\"L0={l0} trained SAE\", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path=f\"plots/toy_l0/sae_l0_{l0}_decoder_cos_sims.pdf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ac7a61a6",
      "metadata": {},
      "source": [
        "# Transitioning to the correct L0\n",
        "\n",
        "Does it make a difference if we start at too low an L0 and transition to the correct L0, vs starting too high?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "267fe73b",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.enchanced_batch_topk_sae import (\n",
        "    EnchancedBatchTopKTrainingSAE,\n",
        "    EnhancedBatchTopKTrainingSAEConfig,\n",
        ")\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "# Start above the true L0 (initial_k=20) and end at the true L0 (k=11).\n",
        "# NOTE(review): presumably k moves from initial_k to k over transition_k_duration_steps\n",
        "# steps, beginning at transition_k_start_step -- confirm against the config class.\n",
        "cfg = EnhancedBatchTopKTrainingSAEConfig(\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=50,\n",
        "    initial_k=20,\n",
        "    k=11,\n",
        "    transition_k_start_step=3_000,\n",
        "    transition_k_duration_steps=8_000,\n",
        ")\n",
        "sae_transition_down = EnchancedBatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# Last expression of the cell, so whatever train_toy_sae returns is shown as the cell output.\n",
        "train_toy_sae(sae_transition_down, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6bd0253a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "# Plot for the SAE trained with L0 decreasing from 20 to 11.\n",
        "plot_sae_feat_cos_sims(\n",
        "    sae_transition_down,\n",
        "    toy_model,\n",
        "    \"SAE L0 Decrease 20 to 11\",\n",
        "    reorder_features=True,\n",
        "    dtick=5,\n",
        ")\n",
        "# Decoder-only version of the same plot, saved as a PDF.\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    sae_transition_down,\n",
        "    toy_model,\n",
        "    title=\"SAE L0 Decrease 20 $\\\\to$ 11\",\n",
        "    reorder_features=True,\n",
        "    decoder_only=True,\n",
        "    width=5,\n",
        "    height=2,\n",
        "    dtick=5,\n",
        "    decoder_title=None,\n",
        "    save_path=\"plots/toy_l0/sae_l0_20_to_11_decoder_cos_sims.pdf\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "eeca81bc",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.enchanced_batch_topk_sae import (\n",
        "    EnchancedBatchTopKTrainingSAE,\n",
        "    EnhancedBatchTopKTrainingSAEConfig,\n",
        ")\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "# Start below the true L0 (initial_k=2) and end at the true L0 (k=11).\n",
        "# NOTE(review): presumably k moves from initial_k to k over transition_k_duration_steps\n",
        "# steps, beginning at transition_k_start_step -- confirm against the config class.\n",
        "cfg = EnhancedBatchTopKTrainingSAEConfig(\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=50,\n",
        "    initial_k=2,\n",
        "    k=11,\n",
        "    transition_k_start_step=0,\n",
        "    transition_k_duration_steps=25_000,\n",
        ")\n",
        "sae_transition_up = EnchancedBatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# Last expression of the cell, so whatever train_toy_sae returns is shown as the cell output.\n",
        "train_toy_sae(sae_transition_up, toy_model, generate_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7f7f8f63",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "# Plot for the SAE trained with L0 increasing from 2 to 11.\n",
        "plot_sae_feat_cos_sims(\n",
        "    sae_transition_up,\n",
        "    toy_model,\n",
        "    \"SAE L0 Increase 2 to 11\",\n",
        "    reorder_features=True,\n",
        "    dtick=5,\n",
        ")\n",
        "# Decoder-only version of the same plot, saved as a PDF.\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    sae_transition_up,\n",
        "    toy_model,\n",
        "    title=\"SAE L0 Increase 2 $\\\\to$ 11\",\n",
        "    reorder_features=True,\n",
        "    decoder_only=True,\n",
        "    width=5,\n",
        "    height=2,\n",
        "    dtick=5,\n",
        "    decoder_title=None,\n",
        "    save_path=\"plots/toy_l0/sae_l0_2_to_11_decoder_cos_sims.pdf\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "43e1bcd5",
      "metadata": {},
      "source": [
        "# Alternate metric: pairwise cosine similarity"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cda805fa",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "import torch\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def calc_pairwise_cosine_similarity_score(\n",
        "    weights: torch.Tensor,\n",
        ") -> float:\n",
        "    \"\"\"Mean absolute cosine similarity over all distinct pairs of rows of ``weights``.\n",
        "\n",
        "    Args:\n",
        "        weights: 2-D tensor; each row is treated as one vector.\n",
        "\n",
        "    Returns:\n",
        "        The average of ``|cos(row_i, row_j)|`` over all ``i < j``, as a Python float.\n",
        "    \"\"\"\n",
        "    unit_rows = torch.nn.functional.normalize(weights, p=2, dim=1)\n",
        "    cos_sims = unit_rows.mm(unit_rows.T)\n",
        "    num_rows = weights.shape[0]\n",
        "    # Keep only entries strictly above the diagonal: each unordered pair is counted\n",
        "    # once and the self-similarities on the diagonal are left out.\n",
        "    pair_mask = torch.ones(num_rows, num_rows, device=weights.device).triu(diagonal=1).bool()\n",
        "    return cos_sims[pair_mask].abs().mean().item()\n",
        "\n",
        "\n",
        "# One record per trained SAE: its L0, its position in the list for that L0 (used as\n",
        "# the seed), and the pairwise cosine similarity score of its decoder weights.\n",
        "data = [\n",
        "    {\n",
        "        \"l0\": l0,\n",
        "        \"seed\": seed,\n",
        "        \"similarity_cosine_dec\": calc_pairwise_cosine_similarity_score(sae.W_dec),\n",
        "    }\n",
        "    for l0, saes in saes_by_k.items()\n",
        "    for seed, sae in enumerate(saes)\n",
        "]\n",
        "df = pd.DataFrame(data)\n",
        "\n",
        "Path(\"plots/toy_similarity_metrics\").mkdir(parents=True, exist_ok=True)\n",
        "# Each entry: (column in df, figure title, y-axis label).\n",
        "metrics = [\n",
        "    (\"similarity_cosine_dec\", \"Decoder Pairwise Cosine Similarity\", \"Cosine Similarity\"),\n",
        "]\n",
        "\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "sns.set_theme()\n",
        "\n",
        "for metric_col, title, ylabel in metrics:\n",
        "    with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "        plt.figure(figsize=(3, 1.5))\n",
        "\n",
        "        # Mean and standard deviation of the metric across seeds at each L0.\n",
        "        grouped = df.groupby(\"l0\")[metric_col]\n",
        "        mean_by_l0 = grouped.mean()\n",
        "        # Hand matplotlib plain NumPy arrays rather than pandas Index/Series objects,\n",
        "        # the same way the sparsity-vs-reconstruction plot in this notebook does.\n",
        "        l0_values = mean_by_l0.index.to_numpy()\n",
        "        mean_values = mean_by_l0.to_numpy()\n",
        "        std_values = grouped.std().to_numpy()\n",
        "\n",
        "        # Mean line with a shaded band of one standard deviation around it.\n",
        "        plt.plot(l0_values, mean_values, linewidth=1.0)\n",
        "        plt.fill_between(\n",
        "            l0_values,\n",
        "            mean_values - std_values,\n",
        "            mean_values + std_values,\n",
        "            alpha=0.3,\n",
        "        )\n",
        "\n",
        "        # Add vertical line at true L0.\n",
        "        # NOTE(review): the label is only rendered if a legend is drawn, and this figure\n",
        "        # has none -- add plt.legend() if the label is meant to be visible.\n",
        "        plt.axvline(x=11, color='red', linestyle='--', linewidth=0.5, alpha=0.7, label='True L0')\n",
        "\n",
        "        plt.title(title)\n",
        "        plt.xlabel(\"SAE L0\")\n",
        "        plt.ylabel(ylabel)\n",
        "        plt.grid(True, alpha=0.3)\n",
        "\n",
        "        plt.tight_layout()\n",
        "\n",
        "        # Save plot; the file name is the metric column without its \"similarity_\" part.\n",
        "        filename = metric_col.replace(\"similarity_\", \"\")\n",
        "        plt.savefig(f\"plots/toy_similarity_metrics/{filename}_vs_l0.pdf\")\n",
        "        plt.show()\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
