{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RDuKeGXCpdwv"
      },
      "source": [
        "# Studying Feature Hedging in single-latent SAEs\n",
        "\n",
        "In matryoshka SAEs, we assume that if an SAE is too narrow to represent both parent and child features, that the SAE will then represent just the parent with no interference from the children, effectively solving absorption. However, this turns out now to be true: narrow SAEs learn a mix of parent and child features.\n",
        "\n",
        "We refer to this phenomenon as **feature hedging**. In Feature hedging, an SAE that is too narrow to represent both a parent and child feature mixes part of the child representation into the parent latent to reduce MSE loss rather than representing the parent cleanly.\n",
        "\n",
        "This notebook explores this behavior in the simplest possible case with 1 parent and 1 child feature, and a single latent SAE. Our goal will be to get the SAE to represent just the parent feature perfectly.\n",
        "\n",
        "Hopefully, if we can solve this problem with a single-latent SAE, then we can use that to solve the Matryoshka SAE problem more generally.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rxKmuMInzQf-"
      },
      "source": [
        "## Toy setup\n",
        "\n",
        "We define a toy model with 2 mutually orthogonal features in a 50d space. We will also give our toy model no bias term to make it easy to see if the SAE decoder bias is doing something strange."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JTlSkStEzPn6"
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2\n",
        "\n",
        "from hedging_paper.toy_models.toy_model import ToyModel\n",
        "\n",
        "\n",
        "DEFAULT_D_IN = 50\n",
        "DEFAULT_D_SAE = 1\n",
        "DEFAULT_NUM_FEATS = 2\n",
        "\n",
        "toy_model = ToyModel(num_feats=DEFAULT_NUM_FEATS, hidden_dim=DEFAULT_D_IN)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEpmZSmIzOb4"
      },
      "source": [
        "## Controlling feature firing and co-occurrence\n",
        "\n",
        "Below, we set up a function `get_training_batch()` which we can use to control how many ground-truth features we have, their firing probabilities and magnitudes, and an option `modify_firing_features` callback which can be used to modify the firing features in a batch, for instance forcing a feature to fire or not fire depending on other firing features.\n",
        "\n",
        "We'll use this to create a parent-child hierarchy between features 0 and 1, with feature 0 being the parent and feature 1 being the child. Feature 1 can only fire if feature 0 is already firing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "j-bHiWYxxdAK"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from functools import partial\n",
        "\n",
        "from hedging_paper.toy_models.get_training_batch import get_training_batch\n",
        "\n",
        "\n",
        "# Feature 0 fires with probability 0.25\n",
        "# Feature 1 fires with probability 0.2 if feature 0 is firing\n",
        "feat_probs = torch.tensor([0.25, 0.2])\n",
        "\n",
        "# set up the parent-child firing relationship\n",
        "def modify_feats(feats: torch.Tensor):\n",
        "    feat_0_fires = feats[:, 0] == 1\n",
        "    feats[~feat_0_fires, 1] = 0\n",
        "    return feats\n",
        "\n",
        "independent_generator = partial(get_training_batch, firing_probabilities=feat_probs)\n",
        "parent_child_generator = partial(get_training_batch, firing_probabilities=feat_probs, modify_firing_features=modify_feats)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eH6Vd19x1PuJ"
      },
      "source": [
        "let's check a sample batch of 30 samples. We see 2 features, and that feature 1 is only active if feature 0 is active."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ot8ZkEn_1S-w",
        "outputId": "6167573e-ae39-4d75-b677-7d4afabc1311"
      },
      "outputs": [],
      "source": [
        "parent_child_generator(30)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vsiiIa-g1gKK"
      },
      "source": [
        "## SAE Training\n",
        "\n",
        "We use SAELens to train our single-latent SAE. The following is a bunch of hacky boilerplate to make our toy setup work with SAELens. You can just run this"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dHrHPPWxpdwz"
      },
      "source": [
        "## Training a single latent SAE\n",
        "\n",
        "First, let's try a standard SAE on fully independent features.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NviVK1-yH0KT",
        "outputId": "2ab8b3a9-1b0b-42da-fb49-4b2f6628a70d"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAEConfig, BaseSAERunnerConfig\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-3,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "indep_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(indep_sae, toy_model, noise_level=0.25)\n",
        "\n",
        "\n",
        "train_toy_sae(indep_sae, toy_model, independent_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "E9QIHfs_H74f",
        "outputId": "d10d338e-a4aa-4c78-8522-2e531da28873"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    indep_sae,\n",
        "    toy_model,\n",
        "    \"Independent features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=True,\n",
        "    save_path=\"plots/single_latent_independent_features.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e-mk3TjJH-cK"
      },
      "source": [
        "That worked exactly as we would hope! The SAE perfectly recovers feature 0 with no interference from feature 1."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 335
        },
        "id": "rGrjNCy7ApMY",
        "outputId": "771132c9-b880-4702-ccbb-fa662d3808da"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_b_dec_feat_cos_sims_seaborn\n",
        "\n",
        "plot_b_dec_feat_cos_sims_seaborn(\n",
        "    indep_sae,\n",
        "    toy_model,\n",
        "    # \"Independent features\",\n",
        "    height=1.5,\n",
        "    width=4,\n",
        "    show_values=True,\n",
        "    one_based_indexing=True,\n",
        "    save_path=\"plots/single_latent_independent_features_b_dec.pdf\",\n",
        ")\n",
        "print(f\"SAE b_dec magnitude: {indep_sae.b_dec.norm():.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pz6ZVdGGDbkO"
      },
      "source": [
        "Interestingly, the SAE decoder bias has learned feature 1 with a reduced magnitude exactly equal to the probability of feature 1 firing. This isn't desirable behavior, but isn't terrible. Ideally the decoder should match our toy model bias of 0.\n",
        "\n",
        "This sort of hedging is incentivized by MSE loss, where it's worth it to get a small error on 80% of the times that feature 1 doesn't fire to reduce the massive squared error the 20% of times that feature 1 does fire."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p7vunMEtpdwz",
        "outputId": "915d84fb-7d55-43eb-ebf4-dfeaf35b0853"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig, BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-3,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "base_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(base_sae, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(base_sae, toy_model, parent_child_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "rVZqBx0Npdw0",
        "outputId": "d373ec8c-9d1b-443d-fa0e-9ee59d838221"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    base_sae,\n",
        "    toy_model,\n",
        "    \"Hierarchical features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=True,\n",
        "    save_path=\"plots/single_latent_hierarchical_features.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bGZchH99pdw0"
      },
      "source": [
        "Sadly, we did not learn just feature 0 on its own as we hoped! Our single latent is mixing in part of feature 1 as well 😢.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Quick demo plot of feature hedging"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig, BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-3,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "hedge_demo_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(hedge_demo_sae, toy_model, noise_level=0)\n",
        "\n",
        "train_toy_sae(hedge_demo_sae, toy_model, parent_child_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims\n",
        "\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    hedge_demo_sae,\n",
        "    toy_model,\n",
        "    \"Feature hedging\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=False,\n",
        "    save_path=\"plots/feature_hedging_example.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")\n",
        "\n",
        "plot_sae_feat_cos_sims(\n",
        "    hedge_demo_sae,\n",
        "    toy_model,\n",
        "    \"Feature hedging\",\n",
        "    height=200,\n",
        "    show_values=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## This would cause absorption if the SAE had enough latents\n",
        "\n",
        "Let's verify that if the SAE had 2 latents instead, we'd get feature absorption instead"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig, BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=2,\n",
        "    l1_coefficient=1e-3,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "abs_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(abs_sae, toy_model, noise_level=0.1)\n",
        "\n",
        "train_toy_sae(abs_sae, toy_model, parent_child_generator, training_tokens=100_000_000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    abs_sae,\n",
        "    toy_model,\n",
        "    \"Feature absorption\",\n",
        "    height=2,\n",
        "    width=5,\n",
        "    show_values=False,\n",
        "    save_path=\"plots/feature_absorption_example.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gv7IC0IrBMSI"
      },
      "source": [
        "# What if the child and parent are only loosly correlated?\n",
        "\n",
        "Above, the child can only fire if the parent is firing. What happens if we relax that constraint, so instead the child fires more with the parent than it fires on its own, but is not perfectly correlated?\n",
        "\n",
        "Below, we set up feature 1 to fire with probability 0.25 given feature 0 fires, but only probability of 0.1 of firing on its own."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "9cLAt2-_BaxS"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from functools import partial\n",
        "\n",
        "from hedging_paper.toy_models.get_training_batch import get_training_batch\n",
        "\n",
        "# Feature 0 fires with probability 0.25\n",
        "# Feature 1 fires with probability 0.2 if feature 0 is firing\n",
        "feat_probs = torch.tensor([0.25, 0.2])\n",
        "# Feature 1 fires with probability 0.1 if feature 0 is not firing\n",
        "solo_firing_prob = 0.1\n",
        "\n",
        "# set up the parent-child firing relationship\n",
        "def partial_cooccurence(feats: torch.Tensor):\n",
        "    solo_firings = firing_features = torch.bernoulli(\n",
        "        torch.tensor(solo_firing_prob).unsqueeze(0).expand(feats.shape[0])\n",
        "    )\n",
        "\n",
        "    feat_0_fires = feats[:, 0] == 1\n",
        "    feats[~feat_0_fires, 1] = solo_firings[~feat_0_fires]\n",
        "    return feats\n",
        "\n",
        "partial_cooccurrence_generator = partial(get_training_batch, firing_probabilities=feat_probs, modify_firing_features=partial_cooccurence)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k_zcp4aYDjGK"
      },
      "source": [
        "let's check a sample batch of 30 samples. We see 2 features, and that feature 1 is only active if feature 0 is active."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6HtrLHLJDifA",
        "outputId": "d5d9caf6-4c12-4b6c-d2c0-b0127e9aebf9"
      },
      "outputs": [],
      "source": [
        "partial_cooccurrence_generator(30)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4xOBEHe9EPUo"
      },
      "source": [
        "Next, let's train a single-latent SAE on these features"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-3,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_low_l1 = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_low_l1, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(partial_sae_low_l1, toy_model, partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_low_l1,\n",
        "    toy_model,\n",
        "    \"Correlated features, low L1 penalty\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=True,\n",
        "    save_path=\"plots/single_latent_correlated_features_low_l1.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hXCRZ0_5ESR8",
        "outputId": "c9376a01-1cdf-44fc-8c77-716fc233353a"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-1,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_high_l1 = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_high_l1, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(partial_sae_high_l1, toy_model, partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "JC2plBhOEwpc",
        "outputId": "56ae3c5f-7479-479a-e1f0-96c69d7d9295"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_high_l1,\n",
        "    toy_model,\n",
        "    \"Correlated features, high L1 penalty\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=True,\n",
        "    save_path=\"plots/single_latent_correlated_features_high_l1.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dJU0ywSqU4cy"
      },
      "source": [
        "### What if we train a wide-enough SAE on these features?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=2,\n",
        "    l1_coefficient=1e-3,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_wide = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_wide, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(partial_sae_wide, toy_model, partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_wide,\n",
        "    toy_model,\n",
        "    \"Correlated features, full-width SAE\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=True,\n",
        "    save_path=\"plots/single_latent_correlated_features_full_width.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gsjAUZ_raHpg"
      },
      "source": [
        "## What happens if the co-occurrence is flipped?\n",
        "\n",
        "This is very bad, but what happens if the feature fires more on its own then it co-occurs with the parent? Do we still see a messed up SAE latent?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "6DWNHNhJaGVX"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from functools import partial\n",
        "from hedging_paper.toy_models.get_training_batch import get_training_batch\n",
        "\n",
        "# Feature 0 fires with probability 0.25\n",
        "# Feature 1 fires with probability 0.1 if feature 0 is firing\n",
        "feat_probs = torch.tensor([0.25, 0.1])\n",
        "# Feature 1 fires with probability 0.2 if feature 0 is not firing\n",
        "solo_firing_prob = 0.2\n",
        "\n",
        "# set up the parent-child firing relationship\n",
        "def partial_cooccurence(feats: torch.Tensor):\n",
        "    solo_firings = torch.bernoulli(\n",
        "        torch.tensor(solo_firing_prob).unsqueeze(0).expand(feats.shape[0])\n",
        "    )\n",
        "\n",
        "    feat_0_fires = feats[:, 0] == 1\n",
        "    feats[~feat_0_fires, 1] = solo_firings[~feat_0_fires]\n",
        "    return feats\n",
        "\n",
        "inv_partial_cooccurrence_generator = partial(get_training_batch, firing_probabilities=feat_probs, modify_firing_features=partial_cooccurence)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mBo8YnnKalFy",
        "outputId": "35e0e63e-3af3-46d4-c60c-f5698eca035d"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=2e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_inv = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_inv, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(partial_sae_inv, toy_model, inv_partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "xLVR7r-Kar-r",
        "outputId": "4ae0f47c-146c-43a8-deed-d11018ea9dba"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_inv,\n",
        "    toy_model,\n",
        "    \"Standard SAE - anti-correlated features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=True,\n",
        "    save_path=\"plots/single_latent_anti_correlated_features.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8PDT7-BYbC-B"
      },
      "source": [
        "The SAE latent now contains a negative component of the child feature, just due to partial co-occurrence! This is so very bad 😱"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 335
        },
        "id": "h9m-rVsFEnX2",
        "outputId": "b83c05fa-171d-4217-e329-d9ae1036084f"
      },
      "outputs": [],
      "source": [
        "\n",
        "from hedging_paper.toy_models.plotting import plot_b_dec_feat_cos_sims\n",
        "\n",
        "plot_b_dec_feat_cos_sims(partial_sae_inv, toy_model, \"Standard SAE - anti-correlated features\", show_values=True)\n",
        "print(f\"SAE b_dec magnitude: {partial_sae_inv.b_dec.norm():.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jqpkwhMBFQhZ"
      },
      "source": [
        "This works because the SAE decoder bias still encodes a smaller magnitude version of feature 1, so the SAE latent encoding a negative component of feature 1 reduces the magnitude of the hedging when the SAE latent fires. The SAE is abusing the knowledge that feature 0 firing means feature 1 is less likely to be firing to reduce the MSE loss penalty for having this always-active component of feature 1 in the SAE decoder bias."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### What if the SAE is wide enough to represent all the features?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=2,\n",
        "    l1_coefficient=2e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_inv_wide = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_inv_wide, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(partial_sae_inv_wide, toy_model, inv_partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_inv_wide,\n",
        "    toy_model,\n",
        "    \"Full-width SAE - anti-correlated features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=True,\n",
        "    save_path=\"plots/single_latent_anti_correlated_features_wide.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CMKk3fYAwjWW"
      },
      "source": [
        "## Plotting loss curves\n",
        "\n",
        "Let's plot the loss as a function of the angle of a vector between the two features. We'll normalize the x-axis to be between 0 and 1, where 0 is the parent feature direction and 1 is the child feature direction. 0.5 is the absorption solution, mixing parent and child together.\n",
        "\n",
        "Below we set up the plotting and calculation helpers for these loss curves. Feel free to just run this.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "id": "aCPpPuHWwjWW"
      },
      "outputs": [],
      "source": [
        "from typing import Callable\n",
        "\n",
        "import torch\n",
        "\n",
        "@torch.no_grad()\n",
        "def calc_loss_curve(\n",
        "    loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
        "    parent_only_prob: float,\n",
        "    parent_and_child_prob: float,\n",
        "    steps: int = 100,\n",
        "    sparsity_coefficient: float = 0.0,\n",
        "    sparsity_p: float = 1.0,\n",
        "    rand_vecs: bool = False,\n",
        ") -> list[float]:\n",
        "    if rand_vecs:\n",
        "        parent_vec = torch.randn(10).float()\n",
        "        child_vec = torch.randn(10).float()\n",
        "        child_vec = child_vec - (child_vec @ parent_vec) * parent_vec / parent_vec.norm()**2\n",
        "        child_vec = child_vec / child_vec.norm()\n",
        "        parent_vec = parent_vec / parent_vec.norm()\n",
        "        parent_and_child_vec = parent_vec + child_vec\n",
        "    else:\n",
        "        parent_vec = torch.tensor([1, 0]).float()\n",
        "        child_vec = torch.tensor([0, 1]).float()\n",
        "        parent_and_child_vec = parent_vec + child_vec\n",
        "\n",
        "    losses = []\n",
        "    for i in range(steps):\n",
        "        portion = i / (steps - 1)\n",
        "        test_latent = parent_vec * (1 - portion) + child_vec * portion\n",
        "        test_latent = test_latent / test_latent.norm()\n",
        "\n",
        "        encode = lambda input_act: (input_act @ test_latent).relu()\n",
        "        decode = lambda hidden_act: hidden_act * test_latent\n",
        "\n",
        "        def calc_loss(input_act: torch.Tensor) -> torch.Tensor:\n",
        "            hidden_act = encode(input_act)\n",
        "            recons = decode(hidden_act)\n",
        "            return loss(recons, input_act) + sparsity_coefficient * hidden_act.norm(\n",
        "                p=sparsity_p\n",
        "            )\n",
        "\n",
        "        parent_loss = calc_loss(parent_vec)\n",
        "        parent_and_child_loss = calc_loss(parent_and_child_vec)\n",
        "        expected_loss = (\n",
        "            parent_loss * parent_only_prob\n",
        "            + parent_and_child_loss * parent_and_child_prob\n",
        "        ).item()\n",
        "        losses.append(expected_loss)\n",
        "    return losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MT-MbSgC8A7y"
      },
      "source": [
        "## Different parent vs child firing probabilities result in different loss landscapes\n",
        "\n",
        "First, let's see what happens if the parent fires on its own 30% of the time, but 15% of the time it fires along with the child"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 542
        },
        "id": "N86kmPjb8ABN",
        "outputId": "f51abd9b-a5aa-4be2-8030-9b96feef643e"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.loss_curves import plot_loss_curve_seaborn\n",
        "\n",
        "plot_loss_curve_seaborn(\n",
        "    parent_only_prob=0.3,\n",
        "    parent_and_child_prob=0.1,\n",
        "    sparsity_coefficient=[0.0, 0.1],\n",
        "    sparsity_p=1.0,\n",
        "    subtitle=r\"skew parent (p($f_1+f_2$) < p($f_1$))\",\n",
        "    rand_vecs=True,\n",
        "    width=3,\n",
        "    height=2,\n",
        "    save_path=\"plots/single_latent_loss_curves_skew_parent.pdf\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.loss_curves import plot_loss_curve_seaborn\n",
        "\n",
        "plot_loss_curve_seaborn(\n",
        "    parent_only_prob=0.1,\n",
        "    parent_and_child_prob=0.3,\n",
        "    sparsity_coefficient=[0.0, 0.1],\n",
        "    sparsity_p=1.0,\n",
        "    subtitle=r\"skew child (p($f_1+f_2$) > p($f_1$))\",\n",
        "    rand_vecs=True,\n",
        "    width=3,\n",
        "    height=2,\n",
        "    save_path=\"plots/single_latent_loss_curves_skew_child.pdf\",\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "bDeRplIZ_Cxq",
        "oZfr2ude2HEi"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "sae-rethink-3xX6iC0B-py3.11",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
