{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "45e8a0ab",
      "metadata": {},
      "source": [
        "# Small toy model experiments\n",
        "\n",
        "In this notebook, we use a tiny toy model, small enough that everything is easy to visualize, to explore how the L0 of an SAE affects the features it learns when the true features are correlated."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8ef804a4",
      "metadata": {},
      "source": [
        "# Positive correlations\n",
        "\n",
        "We'll start with a simple toy model with 5 features and a correlation matrix in which every feature is positively correlated with feature 0 (correlation 0.4). Features 1-4 are uncorrelated with each other."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8f1a13f9",
      "metadata": {},
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "from pathlib import Path\n",
        "import torch\n",
        "from tqdm import tqdm\n",
        "\n",
        "from sparse_but_wrong.toy_models.get_training_batch import get_training_batch\n",
        "from sparse_but_wrong.toy_models.toy_model import ToyModel\n",
        "from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix\n",
        "from sparse_but_wrong.util import DEFAULT_DEVICE\n",
        "\n",
        "tqdm._instances.clear()  # type: ignore\n",
        "\n",
        "# Set up the toy model to have L0=2: each of the 5 features fires with probability 0.4, so 5 * 0.4 = 2 fire on average\n",
        "feat_probs = torch.ones(5) * 2 / 5\n",
        "pos_correlations = torch.zeros(5, 5)\n",
        "pos_correlations[:, 0] = 0.4\n",
        "pos_correlations[0, :] = 0.4\n",
        "pos_correlations.fill_diagonal_(1.0)\n",
        "\n",
        "if Path(\"small_toy_model.pt\").exists():\n",
        "    print(\"Loading toy model from disk\")\n",
        "    toy_model = torch.load(\"small_toy_model.pt\", weights_only=False)\n",
        "else:\n",
        "    toy_model = ToyModel(num_feats=5, hidden_dim=20).to(DEFAULT_DEVICE)\n",
        "    torch.save(toy_model, \"small_toy_model.pt\")\n",
        "\n",
        "\n",
        "generate_batch_pos = partial(\n",
        "    get_training_batch,\n",
        "    firing_probabilities=feat_probs,\n",
        "    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,\n",
        "    correlation_matrix=pos_correlations,\n",
        ")\n",
        "\n",
        "\n",
        "plot_correlation_matrix(pos_correlations, save_path=\"plots/toy_setup_small/positive_correlation_matrix.pdf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8aef584d",
      "metadata": {},
      "source": [
        "## Finding the True L0\n",
        "\n",
        "Next, we calculate the true L0 of this dataset, i.e. the average number of features firing per sample (spoiler: it's ~2)."
      ]
    },
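    {
      "cell_type": "markdown",
      "id": "b7c1e2f0",
      "metadata": {},
      "source": [
        "We can predict this number before sampling. By linearity of expectation, the expected L0 is the sum of the per-feature firing probabilities, 5 * 0.4 = 2, regardless of the correlations, as long as the sampler preserves each feature's marginal firing probability (an assumption about `get_training_batch`, which the empirical check in the next cell tests). The sketch below computes the expected value and compares it with a simulation that, purely for illustration, samples the features *independently* with plain Bernoulli draws instead of calling `get_training_batch`:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "# Marginal firing probability of each of the 5 features, as in the setup cell\n",
        "feat_probs = torch.ones(5) * 2 / 5\n",
        "\n",
        "# Linearity of expectation: E[L0] = sum_i P(feature i fires)\n",
        "expected_l0 = feat_probs.sum().item()\n",
        "\n",
        "# Illustrative simulation with independent Bernoulli firings (not get_training_batch)\n",
        "torch.manual_seed(0)\n",
        "fires = torch.rand(100_000, 5) < feat_probs\n",
        "simulated_l0 = fires.float().sum(dim=-1).mean().item()\n",
        "\n",
        "print(f'Expected L0: {expected_l0:.3f}, simulated L0: {simulated_l0:.3f}')\n",
        "```"
      ]
    },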
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2d2f1ea7",
      "metadata": {},
      "outputs": [],
      "source": [
        "sample = generate_batch_pos(100_000)\n",
        "true_l0 = (sample > 0).float().sum(dim=-1).mean()\n",
        "print(f\"True L0: {true_l0}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "04a8470a",
      "metadata": {},
      "source": [
        "## Training an SAE with the right L0\n",
        "\n",
        "Next, let's train an SAE with the correct L0 and verify that it perfectly recovers the underlying true features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8e8bd749",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=2, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_full = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# NOTE: training occasionally gets stuck in a poor local minimum. If this happens, rerun the cell and it should converge properly.\n",
        "train_toy_sae(sae_full, toy_model, generate_batch_pos)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "af294427",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_full, toy_model, \"SAE L0 = True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/pos_corr_sae_l0_eq_true_l0_decoder_cos_sims.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/pos_corr_sae_l0_eq_true_l0_decoder_cos_sims.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3dc66a9c",
      "metadata": {},
      "source": [
        "As we can see above, the SAE has learned the correct features despite the feature correlations."
      ]
    },
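    {
      "cell_type": "markdown",
      "id": "c3d9a4e1",
      "metadata": {},
      "source": [
        "The check above is visual. A minimal sketch of a numerical version follows: recovery counts as successful if every latent's decoder direction has a cosine similarity above a threshold with exactly one true feature, and no two latents match the same feature. The helper name `recovers_features` and the 0.99 threshold are hypothetical choices for illustration, and the sketch assumes the decoder has shape `(d_sae, d_in)` and the true directions have shape `(d_in, num_feats)`:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "def recovers_features(W_dec, feat_dirs, threshold=0.99):\n",
        "    # Hypothetical helper, not part of sparse_but_wrong or sae_lens.\n",
        "    # W_dec: (d_sae, d_in) decoder matrix; feat_dirs: (d_in, num_feats) true feature directions\n",
        "    # sims[i, j] = cosine similarity between latent i and true feature j\n",
        "    sims = F.cosine_similarity(W_dec.unsqueeze(1), feat_dirs.T.unsqueeze(0), dim=-1)\n",
        "    best_match = sims.argmax(dim=1)\n",
        "    # Every latent must match a different feature, and match it almost exactly\n",
        "    one_to_one = len(set(best_match.tolist())) == feat_dirs.shape[1]\n",
        "    return one_to_one and bool((sims.max(dim=1).values > threshold).all())\n",
        "\n",
        "\n",
        "# Toy check: 5 orthonormal feature directions in a 20-dimensional space\n",
        "torch.manual_seed(0)\n",
        "feat_dirs = torch.linalg.qr(torch.randn(20, 5)).Q\n",
        "perfect_dec = feat_dirs.T[torch.randperm(5)]  # same features, shuffled latent order\n",
        "mixed_dec = feat_dirs.T.clone()\n",
        "mixed_dec[1:4] += 0.5 * feat_dirs[:, 0]  # leak feature 0 into latents 1-3\n",
        "\n",
        "print(recovers_features(perfect_dec, feat_dirs))\n",
        "print(recovers_features(mixed_dec, feat_dirs))\n",
        "```\n",
        "\n",
        "With the notebook's objects, this could be called as `recovers_features(sae_full.W_dec.detach().cpu(), toy_model.embed.weight.detach().cpu())`, assuming `toy_model.embed.weight` has shape `(d_in, num_feats)`, as the `d_in=toy_model.embed.weight.shape[0]` argument used when building the SAE config suggests."
      ]
    },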
    {
      "cell_type": "markdown",
      "id": "dde3bc76",
      "metadata": {},
      "source": [
        "## What if we reduce the L0 slightly below the true L0?\n",
        "\n",
        "If we lower the L0, the SAE must reconstruct each input using, on average, fewer active latents than the number of true features that are firing. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is convenient for our experiments). We'll also initialize the SAE to the correct ground-truth solution, so we can be certain that any deviation from it is caused by gradient pressure rather than a poor local minimum."
      ]
    },
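    {
      "cell_type": "markdown",
      "id": "d5e8f2a6",
      "metadata": {},
      "source": [
        "Why can a BatchTopK SAE have a fractional L0? Instead of keeping the top k latents for each sample, BatchTopK keeps the top k * batch_size activations across the whole batch, so k only fixes the *average* L0 per sample. Below is a minimal sketch of that training-time selection; `batch_topk` is a hypothetical helper, not the sae_lens implementation, and rounding `k * batch_size` down with `int()` is our assumption:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "\n",
        "def batch_topk(pre_acts, k):\n",
        "    # Hypothetical helper: a minimal sketch of BatchTopK, not the sae_lens implementation.\n",
        "    # Keeps the int(k * batch_size) largest activations across the whole batch,\n",
        "    # so k can be fractional and only the average L0 per sample is fixed.\n",
        "    acts = torch.relu(pre_acts)\n",
        "    flat = acts.flatten()\n",
        "    n_keep = min(int(k * acts.shape[0]), flat.numel())\n",
        "    top = torch.topk(flat, n_keep)\n",
        "    out = torch.zeros_like(flat)\n",
        "    out[top.indices] = top.values\n",
        "    return out.view_as(acts)\n",
        "\n",
        "\n",
        "torch.manual_seed(0)\n",
        "pre_acts = torch.randn(1000, 5)  # batch of 1000 samples, 5 latents\n",
        "acts = batch_topk(pre_acts, k=1.8)\n",
        "l0_per_sample = (acts > 0).float().sum(dim=-1)\n",
        "print(f'mean L0: {l0_per_sample.mean().item():.2f}, max L0: {int(l0_per_sample.max())}')\n",
        "```\n",
        "\n",
        "Individual samples can use more or fewer than 1.8 latents; only the batch average is pinned to k."
      ]
    },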
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "337325f4",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_narrow = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "init_sae_to_match_model(sae_narrow, toy_model)\n",
        "\n",
        "train_toy_sae(sae_narrow, toy_model, generate_batch_pos)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cece73e4",
      "metadata": {},
      "outputs": [],
      "source": [
        "import plotly.express as px\n",
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_narrow, toy_model, \"SAE L0 < True L0\", reorder_features=True)\n",
        "px.imshow(pos_correlations, width=400, height=400, color_continuous_scale=\"RdBu\", zmin=-1, zmax=1, title=\"Feature correlation matrix\", origin=\"lower\").show()\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title=\"SAE L0 $<$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/pos_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "00ea495b",
      "metadata": {},
      "source": [
        "We see that all SAE latents contain a strong positive component of feature 0, but no component of any other feature, exactly matching the structure of the correlation matrix. Clearly, the correlations are what is causing the SAE to mix feature components together."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a0b7a631",
      "metadata": {},
      "source": [
        "# Negative correlations\n",
        "\n",
        "Next, we'll change the correlation matrix so that all features are negatively, rather than positively, correlated with feature 0 (correlation -0.4). We'll keep the same toy model as before and change only the feature firing correlations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8670802c",
      "metadata": {},
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "import torch\n",
        "\n",
        "from sparse_but_wrong.toy_models.get_training_batch import get_training_batch\n",
        "from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix\n",
        "\n",
        "# Set up the toy model to have L0=2: each of the 5 features fires with probability 0.4, so 5 * 0.4 = 2 fire on average\n",
        "feat_probs = torch.ones(5) * 2 / 5\n",
        "neg_correlations = torch.zeros(5, 5)\n",
        "neg_correlations[:, 0] = -0.4\n",
        "neg_correlations[0, :] = -0.4\n",
        "neg_correlations.fill_diagonal_(1.0)\n",
        "\n",
        "\n",
        "generate_batch_neg = partial(\n",
        "    get_training_batch,\n",
        "    firing_probabilities=feat_probs,\n",
        "    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,\n",
        "    correlation_matrix=neg_correlations,\n",
        ")\n",
        "\n",
        "plot_correlation_matrix(neg_correlations, save_path=\"plots/toy_setup_small/negative_correlation_matrix.pdf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "167b875f",
      "metadata": {},
      "source": [
        "## Finding the True L0\n",
        "\n",
        "Let's double-check that the true L0 is still 2, as before."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "048132c2",
      "metadata": {},
      "outputs": [],
      "source": [
        "sample = generate_batch_neg(100_000)\n",
        "true_l0 = (sample > 0).float().sum(dim=-1).mean()\n",
        "print(f\"True L0: {true_l0}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "88d79850",
      "metadata": {},
      "source": [
        "## Training an SAE with the right L0\n",
        "\n",
        "Next, let's train an SAE with the correct L0 and verify that it perfectly recovers the underlying true features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a866615e",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=2, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_full = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "# NOTE: training occasionally gets stuck in a poor local minimum. If this happens, rerun the cell and it should converge properly.\n",
        "train_toy_sae(sae_full, toy_model, generate_batch_neg)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9393eee1",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_full, toy_model, \"SAE L0 = True L0\", reorder_features=True, dtick=5)\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/neg_corr_sae_l0_eq_true_l0_decoder_cos_sims.pdf\")\n",
        "plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title=\"SAE L0 = True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/neg_corr_sae_l0_eq_true_l0_decoder_cos_sims.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9562a826",
      "metadata": {},
      "source": [
        "As we can see above, the SAE has learned the correct features despite the feature correlations."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2fc547f4",
      "metadata": {},
      "source": [
        "## What if we reduce the L0 slightly below the true L0?\n",
        "\n",
        "If we lower the L0, the SAE must reconstruct each input using, on average, fewer active latents than the number of true features that are firing. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is convenient for our experiments). We'll also initialize the SAE to the correct ground-truth solution, so we can be certain that any deviation from it is caused by gradient pressure rather than a poor local minimum."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bde8f203",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "sae_narrow = BatchTopKTrainingSAE(cfg)\n",
        "\n",
        "init_sae_to_match_model(sae_narrow, toy_model)\n",
        "\n",
        "train_toy_sae(sae_narrow, toy_model, generate_batch_neg)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "479345a3",
      "metadata": {},
      "outputs": [],
      "source": [
        "import plotly.express as px\n",
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(sae_narrow, toy_model, \"SAE L0 < True L0\", reorder_features=True)\n",
        "px.imshow(neg_correlations, width=400, height=400, color_continuous_scale=\"RdBu\", zmin=-1, zmax=1, title=\"Feature correlation matrix\", origin=\"lower\").show()\n",
        "plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title=\"SAE L0 $<$ True L0\", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path=\"plots/toy_l0_small/neg_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a688d054",
      "metadata": {},
      "source": [
        "# How does the degree of mixing change with amount of correlation?\n",
        "\n",
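        "Next, we vary the correlation between feature 0 and the other features from -0.5 to 0.5 in steps of 0.1, training SAEs with L0 = 1.8 (5 seeds per correlation value). To measure mixing, we take the mean cosine similarity between the decoder directions of latents 1-3 and the true direction of feature 0."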
      ]
    },
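    {
      "cell_type": "markdown",
      "id": "e1f4b7c9",
      "metadata": {},
      "source": [
        "The plotting cells that follow summarize mixing as the mean cosine similarity between the decoder directions of latents 1-3 and the true direction of feature 0, using `cos_sims` from `sparse_but_wrong.util`. Since that helper's implementation is not shown here, the sketch below re-implements the metric with `torch.nn.functional.cosine_similarity`. It assumes, consistent with the call `cos_sims(sae.W_dec.T, toy_model.embed.weight)`, that the decoder has shape `(d_sae, d_in)` and the true directions have shape `(d_in, num_feats)`; `mixing_score` is a hypothetical name:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "def mixing_score(W_dec, feat_dirs, latents=(1, 2, 3), feature=0):\n",
        "    # Hypothetical helper: mean cosine similarity between the decoder rows of\n",
        "    # `latents` and the true direction of `feature`.\n",
        "    # W_dec: (d_sae, d_in) decoder matrix; feat_dirs: (d_in, num_feats) true directions\n",
        "    target = feat_dirs[:, feature].unsqueeze(0)\n",
        "    sims = F.cosine_similarity(W_dec[list(latents)], target, dim=-1)\n",
        "    return sims.mean().item()\n",
        "\n",
        "\n",
        "torch.manual_seed(0)\n",
        "feat_dirs = torch.linalg.qr(torch.randn(20, 5)).Q  # 5 orthonormal features in 20 dims\n",
        "W_dec = feat_dirs.T.clone()  # perfect SAE: latent i is exactly feature i\n",
        "print(mixing_score(W_dec, feat_dirs))  # approximately 0: no mixing\n",
        "\n",
        "W_dec[1:4] += 0.5 * feat_dirs[:, 0]  # add a positive feature-0 component to latents 1-3\n",
        "print(mixing_score(W_dec, feat_dirs))  # 0.5 / sqrt(1.25), about 0.447\n",
        "```\n",
        "\n",
        "A negative feature-0 component in latents 1-3 would give a negative score of the same magnitude, so the sign of the metric tells us which way feature 0 is mixed in."
      ]
    },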
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e9195ba8",
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import defaultdict\n",
        "from functools import partial\n",
        "from pathlib import Path\n",
        "import torch\n",
        "\n",
        "from sparse_but_wrong.toy_models.get_training_batch import get_training_batch\n",
        "\n",
        "\n",
        "def get_generator_for_correlation(correlation_strength: float):\n",
        "    # Set up the toy model to have L0=2: each of the 5 features fires with probability 0.4\n",
        "    feat_probs = torch.ones(5) * 2 / 5\n",
        "    correlations = torch.zeros(5, 5)\n",
        "    correlations[:, 0] = correlation_strength\n",
        "    correlations[0, :] = correlation_strength\n",
        "    correlations.fill_diagonal_(1.0)\n",
        "\n",
        "    return partial(\n",
        "        get_training_batch,\n",
        "        firing_probabilities=feat_probs,\n",
        "        std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,\n",
        "        correlation_matrix=correlations,\n",
        "    )\n",
        "\n",
        "\n",
        "saes_by_correlation = defaultdict(list)\n",
        "for seed in [0, 1, 2, 3, 4]:\n",
        "    for corr in [-0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5]:\n",
        "        Path(f\"saes_by_correlation/seed_{seed}\").mkdir(parents=True, exist_ok=True)\n",
        "        if Path(f\"saes_by_correlation/seed_{seed}/{corr}\").exists():\n",
        "            print(f\"Loading SAE with corr={corr}, seed={seed} from disk\")\n",
        "            sae = BatchTopKTrainingSAE.load_from_disk(f\"saes_by_correlation/seed_{seed}/{corr}\")\n",
        "        else:\n",
        "            print(f\"Training SAE with corr={corr}, seed={seed}\")\n",
        "            cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "            sae = BatchTopKTrainingSAE(cfg)\n",
        "            init_sae_to_match_model(sae, toy_model)\n",
        "            train_toy_sae(sae, toy_model, get_generator_for_correlation(corr))\n",
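        "            # safetensors cannot serialize non-contiguous tensors, so make W_dec contiguous before saving\n",
        "            sae.W_dec.data = sae.W_dec.data.contiguous()\n",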
        "            sae.save_model(f\"saes_by_correlation/seed_{seed}/{corr}\")\n",
        "        saes_by_correlation[corr].append(sae)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "84002d62",
      "metadata": {},
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import numpy as np\n",
        "from pathlib import Path\n",
        "from sparse_but_wrong.util import cos_sims\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "\n",
        "# For each correlation value, calculate mean cosine similarity of latents 1-3 with feature 0\n",
        "correlation_values = sorted(saes_by_correlation.keys())\n",
        "mean_cos_sims_per_corr = []\n",
        "std_cos_sims_per_corr = []\n",
        "\n",
        "for corr in correlation_values:\n",
        "    saes = saes_by_correlation[corr]\n",
        "    cos_sims_for_seeds = []\n",
        "    \n",
        "    for sae in saes:\n",
        "        # Calculate cosine similarities between SAE decoder and true features\n",
        "        dec_cos_sims = cos_sims(sae.W_dec.T, toy_model.embed.weight)\n",
        "        \n",
        "        # Get cosine similarities for latents 1, 2, 3 with feature 0\n",
        "        latents_1_3_with_feat_0 = dec_cos_sims[[1, 2, 3], 0]\n",
        "        \n",
        "        # Take the mean\n",
        "        mean_cos_sim = latents_1_3_with_feat_0.mean().item()\n",
        "        cos_sims_for_seeds.append(mean_cos_sim)\n",
        "    \n",
        "    # Average across seeds\n",
        "    mean_cos_sims_per_corr.append(np.mean(cos_sims_for_seeds))\n",
        "    std_cos_sims_per_corr.append(np.std(cos_sims_for_seeds))\n",
        "\n",
        "# Convert to numpy arrays for easier manipulation\n",
        "mean_cos_sims_per_corr = np.array(mean_cos_sims_per_corr)\n",
        "std_cos_sims_per_corr = np.array(std_cos_sims_per_corr)\n",
        "\n",
        "# Plot\n",
        "Path(\"plots/toy_l0_small\").mkdir(parents=True, exist_ok=True)\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "sns.set_theme()\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(3, 1.5))\n",
        "    \n",
        "    # Plot the mean line\n",
        "    plt.plot(correlation_values, mean_cos_sims_per_corr, linewidth=1.0)\n",
        "    \n",
        "    # Add shaded area for 1 standard deviation\n",
        "    plt.fill_between(correlation_values, \n",
        "                    mean_cos_sims_per_corr - std_cos_sims_per_corr, \n",
        "                    mean_cos_sims_per_corr + std_cos_sims_per_corr, \n",
        "                    alpha=0.3)\n",
        "    \n",
        "    plt.title(\"Feature mixing vs correlation\")\n",
        "    plt.xlabel(\"Correlation between $f_0$ and other features\")\n",
        "    plt.ylabel(\"Mean cos sim\")\n",
        "    plt.grid(True, alpha=0.3)\n",
        "    \n",
        "    plt.tight_layout()\n",
        "    plt.savefig('plots/toy_l0_small/mean_cos_sim_vs_correlation.pdf')\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0e3e7701",
      "metadata": {},
      "outputs": [],
      "source": [
        "for corr in [-0.5, 0.0, 0.5]:\n",
        "    plot_sae_feat_cos_sims_seaborn(\n",
        "        saes_by_correlation[corr][0],\n",
        "        toy_model,\n",
        "        title=f\"correlation {corr}\",\n",
        "        reorder_features=True,\n",
        "        decoder_only=True,\n",
        "        width=5,\n",
        "        height=2,\n",
        "        decoder_title=None,\n",
        "        save_path=f\"plots/toy_l0_small/correlation_{corr}.pdf\",\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "da9ef8a0",
      "metadata": {},
      "source": [
        "# Changing L0 while holding correlation constant\n",
        "\n",
        "Next, we'll see what happens when we change the L0 while holding the correlation constant. We'll train SAEs on data with correlation -0.4 and 0.4, as in our earlier experiments, but vary the SAE L0 from 1.7 to 2.0 (the true L0 of the toy model is 2.0). We can't go much lower than 1.7 without fully breaking the SAE: dropping the L0 too far makes the SAE latents so distorted that they bear almost no resemblance to the true features, which makes it hard to cleanly measure the cosine similarity between latents 1-3 and feature 0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "debff05f",
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import defaultdict\n",
        "from pathlib import Path\n",
        "\n",
        "saes_by_l0 = {\n",
        "    \"pos\": defaultdict(list),\n",
        "    \"neg\": defaultdict(list),\n",
        "}\n",
        "POS_CORR = 0.4\n",
        "NEG_CORR = -0.4\n",
        "\n",
        "for corr_type, corr in [(\"pos\", POS_CORR), (\"neg\", NEG_CORR)]:\n",
        "    for seed in [0, 1, 2, 3, 4]:\n",
        "        for l0 in [1.7, 1.8, 1.9, 2.0]:\n",
        "            Path(f\"saes_by_l0/{corr_type}/seed_{seed}\").mkdir(parents=True, exist_ok=True)\n",
        "            if Path(f\"saes_by_l0/{corr_type}/seed_{seed}/{l0}\").exists():\n",
        "                print(f\"Loading SAE with corr={corr}, l0={l0}, seed={seed} from disk\")\n",
        "                sae = BatchTopKTrainingSAE.load_from_disk(f\"saes_by_l0/{corr_type}/seed_{seed}/{l0}\")\n",
        "            else:\n",
        "                print(f\"Training SAE with corr={corr}, l0={l0}, seed={seed}\")\n",
        "                cfg = BatchTopKTrainingSAEConfig(k=l0, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "                sae = BatchTopKTrainingSAE(cfg)\n",
        "                init_sae_to_match_model(sae, toy_model)\n",
        "                train_toy_sae(sae, toy_model, get_generator_for_correlation(corr))\n",
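        "                # safetensors cannot serialize non-contiguous tensors, so make W_dec contiguous before saving\n",
        "                sae.W_dec.data = sae.W_dec.data.contiguous()\n",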
        "                sae.save_model(f\"saes_by_l0/{corr_type}/seed_{seed}/{l0}\")\n",
        "            saes_by_l0[corr_type][l0].append(sae)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6554cb3f",
      "metadata": {},
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import numpy as np\n",
        "from pathlib import Path\n",
        "from sparse_but_wrong.util import cos_sims\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "\n",
        "Path(\"plots/toy_l0_small\").mkdir(parents=True, exist_ok=True)\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "sns.set_theme()\n",
        "\n",
        "# Plot both positive and negative correlation variants\n",
        "for corr_type, corr_label in [(\"pos\", \"positive\"), (\"neg\", \"negative\")]:\n",
        "    # For each L0 value, calculate mean cosine similarity of latents 1-3 with feature 0\n",
        "    l0_values = sorted(saes_by_l0[corr_type].keys())\n",
        "    mean_cos_sims_per_l0 = []\n",
        "    std_cos_sims_per_l0 = []\n",
        "    \n",
        "    for l0 in l0_values:\n",
        "        saes = saes_by_l0[corr_type][l0]\n",
        "        cos_sims_for_seeds = []\n",
        "        \n",
        "        for sae in saes:\n",
        "            # Calculate cosine similarities between SAE decoder and true features\n",
        "            dec_cos_sims = cos_sims(sae.W_dec.T, toy_model.embed.weight)\n",
        "            \n",
        "            # Get cosine similarities for latents 1, 2, 3 with feature 0\n",
        "            latents_1_3_with_feat_0 = dec_cos_sims[[1, 2, 3], 0]\n",
        "            \n",
        "            # Take the mean\n",
        "            mean_cos_sim = latents_1_3_with_feat_0.mean().item()\n",
        "            cos_sims_for_seeds.append(mean_cos_sim)\n",
        "        \n",
        "        # Average across seeds\n",
        "        mean_cos_sims_per_l0.append(np.mean(cos_sims_for_seeds))\n",
        "        std_cos_sims_per_l0.append(np.std(cos_sims_for_seeds))\n",
        "    \n",
        "    # Convert to numpy arrays for easier manipulation\n",
        "    mean_cos_sims_per_l0 = np.array(mean_cos_sims_per_l0)\n",
        "    std_cos_sims_per_l0 = np.array(std_cos_sims_per_l0)\n",
        "    \n",
        "    # Plot\n",
        "    with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "        plt.figure(figsize=(3, 1.5))\n",
        "        \n",
        "        # Plot the mean line\n",
        "        plt.plot(l0_values, mean_cos_sims_per_l0, linewidth=1.0)\n",
        "        \n",
        "        # Add shaded area for 1 standard deviation\n",
        "        plt.fill_between(l0_values, \n",
        "                        mean_cos_sims_per_l0 - std_cos_sims_per_l0, \n",
        "                        mean_cos_sims_per_l0 + std_cos_sims_per_l0, \n",
        "                        alpha=0.3)\n",
        "        \n",
        "        plt.title(f\"Feature mixing vs L0 ({corr_label} correlation)\")\n",
        "        plt.xlabel(\"SAE L0\")\n",
        "        plt.ylabel(\"Mean cos sim\")\n",
        "        plt.grid(True, alpha=0.3)\n",
        "        \n",
        "        plt.tight_layout()\n",
        "        plt.savefig(f'plots/toy_l0_small/mean_cos_sim_vs_l0_{corr_type}.pdf')\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c89d7271",
      "metadata": {},
      "source": [
        "# Changing firing frequency\n",
        "\n",
        "Next, we study the effect of changing the firing frequency of feature 0 relative to the other features. We'll keep the correlation fixed at 0.4 and -0.4, keep the true L0 of the data fixed at 2.0, and train SAEs with L0 = 1.8."
      ]
    },
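    {
      "cell_type": "markdown",
      "id": "f2a6c8d3",
      "metadata": {},
      "source": [
        "The generator defined below gives each of the four other features firing probability P and feature 0 probability P0 = r * P, where r is the frequency ratio. Keeping the expected L0 at the target gives P0 + 4P = (r + 4)P = 2, so P = 2 / (r + 4) and P0 = 2r / (r + 4). Because P0 cannot exceed 1, the ratio can be at most 4 when the target L0 is 2; at r = 4.0, the largest ratio used here, feature 0 fires on every sample. A quick check of the four ratios used (`firing_probs` is a hypothetical helper that mirrors the generator's formulas):\n",
        "\n",
        "```python\n",
        "def firing_probs(freq_ratio, target_l0=2.0, n_other=4):\n",
        "    # Same formulas as get_generator_for_freq_ratio_and_corr below\n",
        "    other_prob = target_l0 / (freq_ratio + n_other)\n",
        "    f0_prob = freq_ratio * other_prob  # equals target_l0 * freq_ratio / (freq_ratio + n_other)\n",
        "    return f0_prob, other_prob\n",
        "\n",
        "\n",
        "for r in [0.1, 0.2, 1.0, 4.0]:\n",
        "    p0, p = firing_probs(r)\n",
        "    print(f'ratio={r}: P0={p0:.4f}, P_other={p:.4f}, P0/P_other={p0 / p:.2f}, expected L0={p0 + 4 * p:.2f}')\n",
        "```"
      ]
    },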
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ca540802",
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import defaultdict\n",
        "from pathlib import Path\n",
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "saes_by_freq = {\n",
        "    \"pos\": defaultdict(list),\n",
        "    \"neg\": defaultdict(list),\n",
        "}\n",
        "POS_CORR = 0.4\n",
        "NEG_CORR = -0.4\n",
        "L0 = 1.8\n",
        "\n",
        "def get_generator_for_freq_ratio_and_corr(freq_ratio: float, corr: float = 1.0, target_l0: float = 2.0):\n",
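        "    # Choose probabilities so that P(f0) / P(f_i) = freq_ratio and the expected L0 equals target_l0\n",
        "    f0_prob = target_l0 * freq_ratio / (freq_ratio + 4)\n",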
        "    other_probs = target_l0 / (freq_ratio + 4)\n",
        "    feat_probs = torch.ones(5) * other_probs\n",
        "    feat_probs[0] = f0_prob\n",
        "    correlations = torch.zeros(5, 5)\n",
        "    correlations[:, 0] = corr\n",
        "    correlations[0, :] = corr\n",
        "    correlations.fill_diagonal_(1.0)\n",
        "    return partial(get_training_batch, firing_probabilities=feat_probs, std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15, correlation_matrix=correlations)\n",
        "\n",
        "for corr_type, corr in [(\"pos\", POS_CORR), (\"neg\", NEG_CORR)]:\n",
        "    for seed in [0]:\n",
        "        for freq_ratio in [0.1, 0.2, 1.0, 4.0]:\n",
        "            Path(f\"saes_by_freq/{corr_type}/seed_{seed}\").mkdir(parents=True, exist_ok=True)\n",
        "            if Path(f\"saes_by_freq/{corr_type}/seed_{seed}/{freq_ratio}\").exists():\n",
        "                print(f\"Loading SAE with corr={corr}, freq_ratio={freq_ratio}, seed={seed} from disk\")\n",
        "                sae = BatchTopKTrainingSAE.load_from_disk(f\"saes_by_freq/{corr_type}/seed_{seed}/{freq_ratio}\")\n",
        "            else:\n",
        "                print(f\"Training SAE with corr={corr}, l0={L0}, freq_ratio={freq_ratio}, seed={seed}\")\n",
        "                cfg = BatchTopKTrainingSAEConfig(k=L0, d_in=toy_model.embed.weight.shape[0], d_sae=5)\n",
        "                sae = BatchTopKTrainingSAE(cfg)\n",
        "                init_sae_to_match_model(sae, toy_model)\n",
        "                train_toy_sae(sae, toy_model, get_generator_for_freq_ratio_and_corr(corr=corr, freq_ratio=freq_ratio, target_l0=2.0))\n",
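        "                # safetensors cannot serialize non-contiguous tensors, so make W_dec contiguous before saving\n",
        "                sae.W_dec.data = sae.W_dec.data.contiguous()\n",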
        "                sae.save_model(f\"saes_by_freq/{corr_type}/seed_{seed}/{freq_ratio}\")\n",
        "            saes_by_freq[corr_type][freq_ratio].append(sae)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a9ce7697",
      "metadata": {},
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import numpy as np\n",
        "from pathlib import Path\n",
        "from sparse_but_wrong.util import cos_sims\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "\n",
        "Path(\"plots/toy_l0_small\").mkdir(parents=True, exist_ok=True)\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "sns.set_theme()\n",
        "\n",
        "# Plot both positive and negative correlation variants\n",
        "for corr_type, corr_label in [(\"pos\", \"positive\"), (\"neg\", \"negative\")]:\n",
        "    # For each frequency ratio value, calculate mean cosine similarity of latents 1-3 with feature 0\n",
        "    freq_values = sorted(saes_by_freq[corr_type].keys())\n",
        "    mean_cos_sims_per_freq = []\n",
        "    std_cos_sims_per_freq = []\n",
        "    \n",
        "    for freq_ratio in freq_values:\n",
        "        saes = saes_by_freq[corr_type][freq_ratio]\n",
        "        cos_sims_for_seeds = []\n",
        "        \n",
        "        for sae in saes:\n",
        "            # Calculate cosine similarities between SAE decoder and true features\n",
        "            dec_cos_sims = cos_sims(sae.W_dec.T, toy_model.embed.weight)\n",
        "            \n",
        "            # Get cosine similarities for latents 1, 2, 3 with feature 0\n",
        "            latents_1_3_with_feat_0 = dec_cos_sims[[1, 2, 3], 0]\n",
        "            \n",
        "            # Take the mean\n",
        "            mean_cos_sim = latents_1_3_with_feat_0.mean().item()\n",
        "            cos_sims_for_seeds.append(mean_cos_sim)\n",
        "        \n",
        "        # Average across seeds\n",
        "        mean_cos_sims_per_freq.append(np.mean(cos_sims_for_seeds))\n",
        "        std_cos_sims_per_freq.append(np.std(cos_sims_for_seeds))\n",
        "    \n",
        "    # Convert to numpy arrays for easier manipulation\n",
        "    mean_cos_sims_per_freq = np.array(mean_cos_sims_per_freq)\n",
        "    std_cos_sims_per_freq = np.array(std_cos_sims_per_freq)\n",
        "    \n",
        "    # Plot\n",
        "    with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "        plt.figure(figsize=(3, 1.5))\n",
        "        \n",
        "        # Plot the mean line\n",
        "        plt.plot(freq_values, mean_cos_sims_per_freq, linewidth=1.0)\n",
        "        \n",
        "        # Add shaded area for 1 standard deviation\n",
        "        plt.fill_between(freq_values,\n",
        "                        mean_cos_sims_per_freq - std_cos_sims_per_freq, \n",
        "                        mean_cos_sims_per_freq + std_cos_sims_per_freq, \n",
        "                        alpha=0.3)\n",
        "        \n",
        "        plt.title(f\"Feature mixing vs frequency ratio ({corr_label} correlation)\")\n",
        "        plt.xlabel(r\"Frequency ratio $P_0/P_1$\")\n",
        "        plt.ylabel(\"Mean cos sim\")\n",
        "        plt.grid(True, alpha=0.3)\n",
        "        \n",
        "        plt.tight_layout()\n",
        "        plt.savefig(f'plots/toy_l0_small/mean_cos_sim_vs_freq_{corr_type}.pdf')\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4f98677d",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(saes_by_freq[\"pos\"][1][0], toy_model, \"SAE\", reorder_features=True)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "11192a65",
      "metadata": {},
      "source": [
        "# Superposition noise\n",
        "\n",
        "Training without superposition noise is the easiest setting for an SAE, so if an SAE cannot learn the correct latents without superposition, there is no reason to expect it to do so once superposition noise is added. Nevertheless, for completeness, we show below that the problem persists under superposition."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d43be5f7",
      "metadata": {},
      "outputs": [],
      "source": [
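        "from pathlib import Path\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "from sparse_but_wrong.toy_models.toy_model import ToyModel\n",
        "from sparse_but_wrong.util import DEFAULT_DEVICE, cos_sims\n",
        "\n",
        "# Make sure the output directory exists before saving the figure\n",
        "Path(\"plots/toy_setup_small\").mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# Toy model with superposition noise: feature directions are not exactly orthogonal\n",
        "super_toy_model = ToyModel(num_feats=5, hidden_dim=20, ortho_num_steps=4).to(DEFAULT_DEVICE)\n",
        "\n",
        "# Pairwise cosine similarities between the true feature directions (moved to CPU for seaborn)\n",
        "feature_cos_sims = cos_sims(super_toy_model.embed.weight, super_toy_model.embed.weight).detach().cpu()\n",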
        "\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(2.5, 2))\n",
        "    sns.heatmap(feature_cos_sims, cmap=\"RdBu\", center=0, vmin=-1, vmax=1)\n",
        "    plt.gca().invert_yaxis()\n",
        "    plt.xlabel(\"Feature\")\n",
        "    plt.ylabel(\"Feature\")\n",
        "    plt.title(\"Feature cosine similarities (superposition)\")\n",
        "    plt.savefig(\"plots/toy_setup_small/super_toy_model_cos_sims.pdf\")\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ca4e2d1a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "# positive correlations\n",
        "cfg = BatchTopKTrainingSAEConfig(k=1.9, d_in=super_toy_model.embed.weight.shape[0], d_sae=5)\n",
        "super_sae_pos = BatchTopKTrainingSAE(cfg)\n",
        "init_sae_to_match_model(super_sae_pos, super_toy_model)\n",
        "train_toy_sae(super_sae_pos, super_toy_model, generate_batch_pos)\n",
        "\n",
        "# negative correlations\n",
        "cfg = BatchTopKTrainingSAEConfig(k=1.9, d_in=super_toy_model.embed.weight.shape[0], d_sae=5)\n",
        "super_sae_neg = BatchTopKTrainingSAE(cfg)\n",
        "init_sae_to_match_model(super_sae_neg, super_toy_model)\n",
        "train_toy_sae(super_sae_neg, super_toy_model, generate_batch_neg)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "26845a91",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims(super_sae_pos, super_toy_model, \"SAE, pos correlation, superposition\", reorder_features=True)\n",
        "plot_sae_feat_cos_sims_seaborn(super_sae_pos, super_toy_model, title=\"SAE, pos correlation, superposition\", reorder_features=True, decoder_only=True, width=5, height=2, adjust_for_superposition=True, decoder_title=None, save_path=\"plots/toy_l0_small/super_pos_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf\")\n",
        "\n",
        "plot_sae_feat_cos_sims(super_sae_neg, super_toy_model, \"SAE, neg correlation, superposition\", reorder_features=True)\n",
        "plot_sae_feat_cos_sims_seaborn(super_sae_neg, super_toy_model, title=\"SAE, neg correlation, superposition\", reorder_features=True, decoder_only=True, width=5, height=2, adjust_for_superposition=True, decoder_title=None, save_path=\"plots/toy_l0_small/super_neg_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf\")\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
