{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RDuKeGXCpdwv"
      },
      "source": [
        "# Studying Feature Hedging in toy model SAEs\n",
        "\n",
        "In matryoshka SAEs, we assume that an SAE too narrow to represent both parent and child features will represent just the parent with no interference from the children, effectively solving absorption. However, this turns out not to be true: narrow SAEs learn a mix of parent and child features.\n",
        "\n",
        "We refer to this phenomenon as **feature hedging**. In feature hedging, an SAE that is too narrow to represent both a parent and a child feature mixes part of the child representation into the parent latent to reduce MSE loss, rather than representing the parent cleanly.\n",
        "\n",
        "This notebook explores this behavior in a simple toy model with 4 features. There is 1 parent and 1 child feature, and the remaining 2 features are independent. We will train an SAE with 3 latents, which is not wide enough to represent all 4 features. Our goal is to get the SAE to represent 3 of the features perfectly, without hedging.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rxKmuMInzQf-"
      },
      "source": [
        "## Toy setup\n",
        "\n",
        "We define a toy model with 4 mutually orthogonal features in a 50d space."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JTlSkStEzPn6"
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2\n",
        "\n",
        "from hedging_paper.toy_models.toy_model import ToyModel\n",
        "\n",
        "DEFAULT_D_IN = 50\n",
        "DEFAULT_D_SAE = 3\n",
        "DEFAULT_NUM_FEATS = 4\n",
        "\n",
        "toy_model = ToyModel(num_feats=DEFAULT_NUM_FEATS, hidden_dim=DEFAULT_D_IN)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEpmZSmIzOb4"
      },
      "source": [
        "## Controlling feature firing and co-occurrence\n",
        "\n",
        "Below, we set up a function `get_training_batch()` that controls how many ground-truth features we have and their firing probabilities and magnitudes. It also accepts an optional `modify_firing_features` callback that can modify the firing features in a batch, for instance forcing a feature to fire or not fire depending on which other features are firing.\n",
        "\n",
        "We'll use this to create a parent-child hierarchy between features 3 and 4, with feature 3 being the parent and feature 4 being the child. Feature 4 can only fire if feature 3 is already firing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j-bHiWYxxdAK"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from functools import partial\n",
        "\n",
        "from hedging_paper.toy_models.get_training_batch import get_training_batch\n",
        "\n",
        "\n",
        "# Features 1, 2, and 3 fire with probability 0.25\n",
        "# Feature 4 fires with probability 0.2 only if feature 3 fires\n",
        "feat_probs = torch.tensor([0.25, 0.25, 0.25, 0.2])\n",
        "\n",
        "# set up the parent-child firing relationship\n",
        "def modify_feats(feats: torch.Tensor):\n",
        "    parent_feat_fires = feats[:, 2] == 1\n",
        "    feats[~parent_feat_fires, 3] = 0\n",
        "    return feats\n",
        "\n",
        "independent_generator = partial(get_training_batch, firing_probabilities=feat_probs)\n",
        "parent_child_generator = partial(get_training_batch, firing_probabilities=feat_probs, modify_firing_features=modify_feats)"
      ]
    },
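    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check of the hierarchy rule, here is a minimal standalone sketch that samples firings with Python's `random` module and applies the same masking as `modify_feats`. It deliberately does not call `get_training_batch`, whose internals are not shown in this notebook, and the `check_`-prefixed names are ours. Under these assumptions the child should fire on roughly 0.25 × 0.2 = 5% of samples, and never without the parent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "# Hypothetical standalone check: samples firings directly instead of\n",
        "# calling get_training_batch.\n",
        "check_rng = random.Random(0)\n",
        "check_probs = [0.25, 0.25, 0.25, 0.2]\n",
        "check_n_samples = 200_000\n",
        "\n",
        "n_parent = 0\n",
        "n_child = 0\n",
        "n_child_without_parent = 0\n",
        "for _ in range(check_n_samples):\n",
        "    fires = [check_rng.random() < p for p in check_probs]\n",
        "    # Same rule as modify_feats: the child (index 3) is switched off\n",
        "    # whenever the parent (index 2) is not firing.\n",
        "    if not fires[2]:\n",
        "        fires[3] = False\n",
        "    n_parent += fires[2]\n",
        "    n_child += fires[3]\n",
        "    n_child_without_parent += fires[3] and not fires[2]\n",
        "\n",
        "child_rate = n_child / check_n_samples\n",
        "child_rate_given_parent = n_child / n_parent\n",
        "print(f\"P(child) ~ {child_rate:.3f}\")\n",
        "print(f\"P(child | parent) ~ {child_rate_given_parent:.3f}\")\n",
        "print(f\"child firings without parent: {n_child_without_parent}\")"
      ]
    },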
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vsiiIa-g1gKK"
      },
      "source": [
        "## SAE Training\n",
        "\n",
        "We use SAELens to train our 3-latent SAE."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dHrHPPWxpdwz"
      },
      "source": [
        "## Training a 3-latent SAE\n",
        "\n",
        "First, let's try a standard SAE on fully independent features.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NviVK1-yH0KT",
        "outputId": "2ab8b3a9-1b0b-42da-fb49-4b2f6628a70d"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAEConfig, BaseSAERunnerConfig\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "indep_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "\n",
        "train_toy_sae(indep_sae, toy_model, independent_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "E9QIHfs_H74f",
        "outputId": "d10d338e-a4aa-4c78-8522-2e531da28873"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    indep_sae,\n",
        "    toy_model,\n",
        "    \"Independent features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=False,\n",
        "    save_path=\"plots/three_latent_independent_features.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e-mk3TjJH-cK"
      },
      "source": [
        "That worked exactly as we would hope! The SAE perfectly recovers feature 3 with no interference from feature 4."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 335
        },
        "id": "rGrjNCy7ApMY",
        "outputId": "771132c9-b880-4702-ccbb-fa662d3808da"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_b_dec_feat_cos_sims_seaborn\n",
        "\n",
        "plot_b_dec_feat_cos_sims_seaborn(\n",
        "    indep_sae,\n",
        "    toy_model,\n",
        "    # \"Independent features\",\n",
        "    height=1.5,\n",
        "    width=4,\n",
        "    show_values=False,\n",
        "    one_based_indexing=True,\n",
        "    save_path=\"plots/three_latent_independent_features_b_dec.pdf\",\n",
        ")\n",
        "print(f\"SAE b_dec magnitude: {indep_sae.b_dec.norm():.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pz6ZVdGGDbkO"
      },
      "source": [
        "Interestingly, the SAE decoder bias has learned feature 4 with a reduced magnitude exactly equal to the probability of feature 4 firing. This isn't desirable behavior, but it isn't terrible. Ideally the decoder bias should match our toy model bias of 0.\n",
        "\n",
        "This sort of hedging is incentivized by MSE loss: it is worth accepting a small error on the 80% of samples where feature 4 doesn't fire in order to reduce the large squared error on the 20% of samples where it does."
      ]
    },
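    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A short worked derivation of why the MSE-optimal bias coefficient equals the firing probability, under simplifying assumptions that are ours rather than read off the training code: a feature fires with probability $p$ at unit magnitude, its direction is orthogonal to everything the latents reconstruct, and no latent reconstructs it. If the decoder bias puts a coefficient $b$ on that direction, the expected squared error along it is\n",
        "\n",
        "$$L(b) = p\\,(1-b)^2 + (1-p)\\,b^2.$$\n",
        "\n",
        "Setting the derivative to zero gives\n",
        "\n",
        "$$\\frac{dL}{db} = -2p\\,(1-b) + 2\\,(1-p)\\,b = 2\\,(b - p) = 0 \\quad\\Rightarrow\\quad b = p.$$\n",
        "\n",
        "For $p = 0.2$ this lowers the expected error from $L(0) = 0.2$ to $L(0.2) = 0.2 \\cdot 0.64 + 0.8 \\cdot 0.04 = 0.16$."
      ]
    },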
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAEConfig, BaseSAERunnerConfig\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=toy_model.embed.weight.shape[1],\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "indep_sae_full = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "\n",
        "train_toy_sae(indep_sae_full, toy_model, independent_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    indep_sae_full,\n",
        "    toy_model,\n",
        "    \"Independent features\",\n",
        "    height=1.75,\n",
        "    width=5,\n",
        "    show_values=False,\n",
        "    save_path=\"plots/three_latent_independent_features_full_width.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p7vunMEtpdwz",
        "outputId": "915d84fb-7d55-43eb-ebf4-dfeaf35b0853"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig, BaseSAEConfig\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "hier_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "\n",
        "train_toy_sae(hier_sae, toy_model, parent_child_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "rVZqBx0Npdw0",
        "outputId": "d373ec8c-9d1b-443d-fa0e-9ee59d838221"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    hier_sae,\n",
        "    toy_model,\n",
        "    \"Hierarchical features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    save_path=\"plots/three_latent_hierarchical_features.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bGZchH99pdw0"
      },
      "source": [
        "Sadly, we did not learn just feature 3 on its own as we hoped! The parent latent is mixing in part of feature 4 (the child) as well 😢.\n"
      ]
    },
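    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why mixing in the child lowers MSE, consider a simplified model that is our assumption rather than the trained SAE: parent and child are orthogonal unit directions, one latent fires with activation 1 exactly when the parent fires, and the decoder bias and sparsity penalty are ignored. The MSE-optimal decoder is then the mean input while the latent is active, which is the parent direction plus the child direction scaled by P(child | parent). The sketch below computes the resulting cosine similarities; the `hedge_`-prefixed names are ours, and a trained SAE need not match these values."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Assumption (ours): P(child fires | parent fires), matching feat_probs above.\n",
        "hedge_p_child_given_parent = 0.2\n",
        "\n",
        "# Mean input while the latent is active, in (parent, child) coordinates.\n",
        "# Under our assumptions this is the MSE-optimal decoder direction.\n",
        "hedge_decoder = (1.0, hedge_p_child_given_parent)\n",
        "\n",
        "hedge_norm = math.hypot(*hedge_decoder)\n",
        "hedge_cos_parent = hedge_decoder[0] / hedge_norm\n",
        "hedge_cos_child = hedge_decoder[1] / hedge_norm\n",
        "print(f\"cos sim with parent: {hedge_cos_parent:.3f}\")\n",
        "print(f\"cos sim with child:  {hedge_cos_child:.3f}\")"
      ]
    },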
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig, BaseSAEConfig\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=toy_model.embed.weight.shape[1],\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "hier_sae_full = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "\n",
        "train_toy_sae(hier_sae_full, toy_model, parent_child_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    hier_sae_full,\n",
        "    toy_model,\n",
        "    \"Hierarchical features\",\n",
        "    height=1.75,\n",
        "    width=5,\n",
        "    save_path=\"plots/three_latent_hierarchical_features_full_width.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Quick demo plot of feature hedging"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig, BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "hedge_demo_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(hedge_demo_sae, toy_model, noise_level=0)\n",
        "\n",
        "train_toy_sae(hedge_demo_sae, toy_model, parent_child_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims\n",
        "\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    hedge_demo_sae,\n",
        "    toy_model,\n",
        "    \"Feature hedging\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    show_values=False,\n",
        "    save_path=\"plots/three_latent_feature_hedging_example.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")\n",
        "\n",
        "plot_sae_feat_cos_sims(\n",
        "    hedge_demo_sae,\n",
        "    toy_model,\n",
        "    \"Feature hedging\",\n",
        "    height=200,\n",
        "    show_values=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## This would cause absorption if the SAE had enough latents\n",
        "\n",
        "Let's verify that if the SAE had 4 latents, we'd get feature absorption instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig, BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=toy_model.embed.weight.shape[1],\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "abs_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(abs_sae, toy_model, noise_level=0.1)\n",
        "\n",
        "train_toy_sae(abs_sae, toy_model, parent_child_generator, training_tokens=100_000_000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    abs_sae,\n",
        "    toy_model,\n",
        "    \"Feature absorption\",\n",
        "    height=1.75,\n",
        "    width=5,\n",
        "    show_values=False,\n",
        "    save_path=\"plots/four_latent_feature_absorption_example.pdf\",\n",
        "    one_based_indexing=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gv7IC0IrBMSI"
      },
      "source": [
        "# What if the child and parent are only loosely correlated?\n",
        "\n",
        "Above, the child can only fire if the parent is firing. What happens if we relax that constraint, so that the child fires more often with the parent than on its own, but is not perfectly correlated with it?\n",
        "\n",
        "Below, we set up feature 4 to fire with probability 0.2 when feature 3 fires, but with probability only 0.1 when feature 3 is not firing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9cLAt2-_BaxS"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from functools import partial\n",
        "\n",
        "from hedging_paper.toy_models.get_training_batch import get_training_batch\n",
        "\n",
        "# Features 1-3 fire with probability 0.25\n",
        "# Feature 4 fires with probability 0.2 if feature 3 is firing\n",
        "feat_probs = torch.tensor([0.25, 0.25, 0.25, 0.2])\n",
        "# Feature 4 fires with probability 0.1 if feature 3 is not firing\n",
        "solo_firing_prob = 0.1\n",
        "\n",
        "# set up the parent-child firing relationship\n",
        "def partial_cooccurence(feats: torch.Tensor):\n",
        "    solo_firings = torch.bernoulli(\n",
        "        torch.tensor(solo_firing_prob).unsqueeze(0).expand(feats.shape[0])\n",
        "    )\n",
        "\n",
        "    feat_3_fires = feats[:, 2] == 1\n",
        "    feats[~feat_3_fires, 3] = solo_firings[~feat_3_fires]\n",
        "    return feats\n",
        "\n",
        "partial_cooccurrence_generator = partial(get_training_batch, firing_probabilities=feat_probs, modify_firing_features=partial_cooccurence)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4xOBEHe9EPUo"
      },
      "source": [
        "Next, let's train a 3-latent SAE on these features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae, toy_model, noise_level=0.1)\n",
        "\n",
        "train_toy_sae(partial_sae, toy_model, partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae,\n",
        "    toy_model,\n",
        "    \"Correlated features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    save_path=\"plots/three_latent_correlated_features.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hXCRZ0_5ESR8",
        "outputId": "c9376a01-1cdf-44fc-8c77-716fc233353a"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=1.0,  # much larger sparsity penalty than the 1e-2 used above\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_high_l1 = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_high_l1, toy_model, noise_level=0.1)\n",
        "\n",
        "train_toy_sae(partial_sae_high_l1, toy_model, partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "JC2plBhOEwpc",
        "outputId": "56ae3c5f-7479-479a-e1f0-96c69d7d9295"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_high_l1,\n",
        "    toy_model,\n",
        "    \"Correlated features, high L1 penalty\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    save_path=\"plots/three_latent_correlated_features_high_l1.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dJU0ywSqU4cy"
      },
      "source": [
        "### What if we train a wide-enough SAE on these features?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=toy_model.embed.weight.shape[1],\n",
        "    l1_coefficient=1e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_wide = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_wide, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(partial_sae_wide, toy_model, partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_wide,\n",
        "    toy_model,\n",
        "    \"Correlated features, full-width SAE\",\n",
        "    height=1.75,\n",
        "    width=5,\n",
        "    # show_values=True,\n",
        "    save_path=\"plots/three_latent_correlated_features_full_width.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gsjAUZ_raHpg"
      },
      "source": [
        "## What happens if the co-occurrence is flipped?\n",
        "\n",
        "This is very bad, but what happens if the child feature fires more often on its own than it co-occurs with the parent? Do we still see a messed-up SAE latent?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6DWNHNhJaGVX"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from functools import partial\n",
        "from hedging_paper.toy_models.get_training_batch import get_training_batch\n",
        "\n",
        "# Features 1-3 fire with probability 0.25\n",
        "# Feature 4 fires with probability 0.1 if feature 3 is firing\n",
        "feat_probs = torch.tensor([0.25, 0.25, 0.25, 0.1])\n",
        "# Feature 4 fires with probability 0.2 if feature 3 is not firing\n",
        "solo_firing_prob = 0.2\n",
        "\n",
        "# set up the parent-child firing relationship\n",
        "def partial_cooccurence(feats: torch.Tensor):\n",
        "    solo_firings = torch.bernoulli(\n",
        "        torch.tensor(solo_firing_prob).unsqueeze(0).expand(feats.shape[0])\n",
        "    )\n",
        "\n",
        "    parent_feat_fires = feats[:, 2] == 1\n",
        "    feats[~parent_feat_fires, 3] = solo_firings[~parent_feat_fires]\n",
        "    return feats\n",
        "\n",
        "inv_partial_cooccurrence_generator = partial(get_training_batch, firing_probabilities=feat_probs, modify_firing_features=partial_cooccurence)"
      ]
    },
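    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The three co-occurrence settings used in this notebook can be compared through the covariance between the parent and child firing indicators, Cov = P(both fire) - P(parent) P(child). The sketch below computes it analytically from the firing probabilities in the code cells above; the helper `parent_child_covariance` and the setting labels are ours. A positive value means the child fires more often alongside the parent, and a negative value means it fires less often alongside the parent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def parent_child_covariance(p_parent, p_child_with_parent, p_child_alone):\n",
        "    \"\"\"Covariance of the parent and child firing indicators (hypothetical helper).\"\"\"\n",
        "    p_both = p_parent * p_child_with_parent\n",
        "    p_child = p_both + (1 - p_parent) * p_child_alone\n",
        "    return p_both - p_parent * p_child\n",
        "\n",
        "\n",
        "# (P(parent), P(child | parent fires), P(child | parent does not fire))\n",
        "cooccurrence_settings = {\n",
        "    \"strict hierarchy\": (0.25, 0.2, 0.0),\n",
        "    \"loose correlation\": (0.25, 0.2, 0.1),\n",
        "    \"flipped co-occurrence\": (0.25, 0.1, 0.2),\n",
        "}\n",
        "covariances = {\n",
        "    name: parent_child_covariance(*args)\n",
        "    for name, args in cooccurrence_settings.items()\n",
        "}\n",
        "for name, cov in covariances.items():\n",
        "    print(f\"{name}: {cov:+.5f}\")"
      ]
    },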
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mBo8YnnKalFy",
        "outputId": "35e0e63e-3af3-46d4-c60c-f5698eca035d"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=DEFAULT_D_SAE,\n",
        "    l1_coefficient=2e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_inv = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "\n",
        "train_toy_sae(partial_sae_inv, toy_model, inv_partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 317
        },
        "id": "xLVR7r-Kar-r",
        "outputId": "4ae0f47c-146c-43a8-deed-d11018ea9dba"
      },
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_inv,\n",
        "    toy_model,\n",
        "    \"Standard SAE - anti-correlated features\",\n",
        "    height=1.5,\n",
        "    width=5,\n",
        "    save_path=\"plots/three_latent_anti_correlated_features.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8PDT7-BYbC-B"
      },
      "source": [
        "The SAE latent now contains a negative component of the child feature, just due to partial co-occurrence! This is so very bad 😱"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 335
        },
        "id": "h9m-rVsFEnX2",
        "outputId": "b83c05fa-171d-4217-e329-d9ae1036084f"
      },
      "outputs": [],
      "source": [
        "\n",
        "from hedging_paper.toy_models.plotting import plot_b_dec_feat_cos_sims\n",
        "\n",
        "plot_b_dec_feat_cos_sims(partial_sae_inv, toy_model, \"Standard SAE - anti-correlated features\", show_values=True)\n",
        "print(f\"SAE b_dec magnitude: {partial_sae_inv.b_dec.norm():.3f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAERunnerConfig\n",
        "from hedging_paper.saes.base_sae import BaseSAEConfig\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=toy_model.embed.weight.shape[1],\n",
        "    l1_coefficient=2e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_inv_full = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "\n",
        "train_toy_sae(partial_sae_inv_full, toy_model, inv_partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_inv_full,\n",
        "    toy_model,\n",
        "    \"Standard SAE - anti-correlated features\",\n",
        "    height=1.75,\n",
        "    width=5,\n",
        "    save_path=\"plots/three_latent_anti_correlated_features_full_width.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jqpkwhMBFQhZ"
      },
      "source": [
        "This works because the SAE decoder bias still encodes a smaller-magnitude version of feature 1. The negative component of feature 1 in the SAE latent cancels part of that bias whenever the latent fires, which reduces the magnitude of the hedging. The SAE is exploiting the fact that feature 1 is less likely to be firing when feature 0 fires, and uses it to reduce the MSE penalty for keeping an always-active component of feature 1 in the decoder bias."
      ]
    },
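    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check on the explanation above, here is a simplified worked example. It rests on assumptions that are not part of the experiment: we look only at the feature 1 direction of the reconstruction, ignore the sparsity penalty, assume feature 1 has magnitude 1 when active, and assume the latent fires with activation 1 exactly when feature 0 is active. The symbols $x$, $y$, $b$, and $c$ are introduced here purely for illustration.\n",
        "\n",
        "Let $x \\in \\{0, 1\\}$ indicate whether feature 0 is active and $y \\in \\{0, 1\\}$ whether feature 1 is active, with $0 < P(x = 1) < 1$. If $b$ is the decoder bias component along feature 1 and $c$ is the latent's decoder component along feature 1, the reconstruction along that direction is $b + c\\,x$, and the expected squared error is\n",
        "\n",
        "$$\\mathcal{L}(b, c) = \\mathbb{E}\\left[(y - b - c\\,x)^2\\right].$$\n",
        "\n",
        "Because $x$ is binary, the minimizer $(\\hat{b}, \\hat{c})$ matches the conditional means of $y$:\n",
        "\n",
        "$$\\hat{b} = P(y = 1 \\mid x = 0), \\qquad \\hat{b} + \\hat{c} = P(y = 1 \\mid x = 1),$$\n",
        "\n",
        "so\n",
        "\n",
        "$$\\hat{c} = P(y = 1 \\mid x = 1) - P(y = 1 \\mid x = 0).$$\n",
        "\n",
        "When the features are anti-correlated, $P(y = 1 \\mid x = 1) < P(y = 1 \\mid x = 0)$, so $\\hat{c} < 0$. Under these assumptions the MSE-optimal latent carries a negative component of feature 1, while $\\hat{b} > 0$ whenever feature 1 sometimes fires without feature 0."
      ]
    },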
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### What if the SAE is wide enough to represent all the features?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
        "from hedging_paper.saes.base_sae import BaseSAE, BaseSAEConfig, BaseSAERunnerConfig\n",
        "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
        "\n",
        "\n",
        "cfg = BaseSAERunnerConfig(\n",
        "    context_size=500,\n",
        "    d_in=toy_model.embed.weight.shape[0],\n",
        "    d_sae=toy_model.embed.weight.shape[1],\n",
        "    l1_coefficient=2e-2,\n",
        "    normalize_sae_decoder=False,\n",
        "    scale_sparsity_penalty_by_decoder_norm=True,\n",
        "    init_encoder_as_decoder_transpose=True,\n",
        "    apply_b_dec_to_input=True,\n",
        "    b_dec_init_method=\"zeros\",\n",
        ")\n",
        "partial_sae_inv_wide = BaseSAE(BaseSAEConfig.from_sae_runner_config(cfg))\n",
        "init_sae_to_match_model(partial_sae_inv_wide, toy_model, noise_level=0.25)\n",
        "\n",
        "train_toy_sae(partial_sae_inv_wide, toy_model, inv_partial_cooccurrence_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
        "\n",
        "plot_sae_feat_cos_sims_seaborn(\n",
        "    partial_sae_inv_wide,\n",
        "    toy_model,\n",
        "    \"Full-width SAE - anti-correlated features\",\n",
        "    height=1.75,\n",
        "    width=5,\n",
        "    save_path=\"plots/three_latent_anti_correlated_features_wide.pdf\",\n",
        "    one_based_indexing=True,\n",
        "    reorder_latents=True,\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "bDeRplIZ_Cxq",
        "oZfr2ude2HEi"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "hedging-paper-unqjxfXk-py3.11",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
