{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "9084cf98",
      "metadata": {},
      "source": [
        "# Toy Model Experiments: JumpReLU edition\n",
        "\n",
        "We focus on BatchTopK SAEs in the paper since it's easier to directly control their L0, but here we show that the same findings also hold for JumpReLU SAEs. We only run the $n$th decoder projection experiment here, showing that JumpReLU SAEs also minimize this metric at the correct L0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8f1a13f9",
      "metadata": {},
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "from pathlib import Path\n",
        "import torch\n",
        "import pandas as pd\n",
        "from tqdm import tqdm\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from sparse_but_wrong.toy_models.get_training_batch import generate_random_correlation_matrix, get_training_batch\n",
        "from sparse_but_wrong.toy_models.toy_model import ToyModel\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "from sparse_but_wrong.util import DEFAULT_DEVICE\n",
        "\n",
        "tqdm._instances.clear()  # type: ignore\n",
        "\n",
        "# Toy setup configuration. correlations.pt and toy_model.pt are reused when they exist,\n",
        "# so delete them after changing any of these values.\n",
        "NUM_FEATS = 50\n",
        "HIDDEN_DIM = 100\n",
        "SEED = 42\n",
        "\n",
        "# Linearly decreasing firing probabilities: ~0.388 for feature 0 down to 0.05 for the last feature\n",
        "feat_probs = 0.345 * (NUM_FEATS - torch.arange(NUM_FEATS) - 1) / NUM_FEATS + 0.05\n",
        "\n",
        "if Path(\"correlations.pt\").exists():\n",
        "    print(\"Loading correlations from disk\")\n",
        "    correlations = torch.load(\"correlations.pt\")\n",
        "else:\n",
        "    correlations = generate_random_correlation_matrix(\n",
        "        correlation_strength_range=(0.3, 0.9),\n",
        "        num_features=NUM_FEATS,\n",
        "        seed=SEED,\n",
        "    )\n",
        "    torch.save(correlations, \"correlations.pt\")\n",
        "\n",
        "# Firing probabilities with 1-indexed feature names, for the bar plot below\n",
        "df = pd.DataFrame({\n",
        "    \"P_i\": feat_probs.tolist(),\n",
        "    \"feature\": [str(i) for i in range(1, NUM_FEATS + 1)],\n",
        "})\n",
        "\n",
        "if Path(\"toy_model.pt\").exists():\n",
        "    print(\"Loading toy model from disk\")\n",
        "    # weights_only=False is needed to unpickle the whole ToyModel object; only load trusted files\n",
        "    toy_model = torch.load(\"toy_model.pt\", weights_only=False)\n",
        "else:\n",
        "    # Seed so a fresh, uncached run is reproducible (assumes ToyModel draws its init from torch's global RNG)\n",
        "    torch.manual_seed(SEED)\n",
        "    toy_model = ToyModel(num_feats=NUM_FEATS, hidden_dim=HIDDEN_DIM).to(DEFAULT_DEVICE)\n",
        "    torch.save(toy_model, \"toy_model.pt\")\n",
        "\n",
        "generate_batch = partial(\n",
        "    get_training_batch,\n",
        "    firing_probabilities=feat_probs,\n",
        "    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,\n",
        "    correlation_matrix=correlations,\n",
        ")\n",
        "\n",
        "Path(\"plots/toy_setup\").mkdir(parents=True, exist_ok=True)\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(3, 2))\n",
        "    sns.barplot(data=df, x=\"feature\", y=\"P_i\")\n",
        "    plt.xlabel(\"Feature\")\n",
        "    plt.ylabel(\"$P_i$\")\n",
        "    plt.title(\"Feature firing probabilities $P_i$\")\n",
        "    # Label only every 5th bar (0-indexed positions) to prevent overlapping tick labels\n",
        "    plt.xticks(range(0, len(df), 5), [str(i) for i in range(0, len(df), 5)])\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(\"plots/toy_setup/toy_model_feature_firing_probabilities.pdf\")\n",
        "    plt.show()\n",
        "\n",
        "# Label every 10th feature (0-indexed). seaborn centres heatmap cells at i + 0.5, so shift\n",
        "# the tick positions by half a cell to line each label up with its row/column.\n",
        "corr_tick_labels = [str(i) for i in range(0, len(correlations), 10)]\n",
        "corr_tick_positions = [i + 0.5 for i in range(0, len(correlations), 10)]\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(2.5, 2))\n",
        "    sns.heatmap(correlations, cmap=\"RdBu\", center=0, vmin=-1, vmax=1)\n",
        "    plt.xlabel(\"Feature\")\n",
        "    plt.ylabel(\"Feature\")\n",
        "    plt.title(\"Feature correlation matrix\")\n",
        "    plt.xticks(corr_tick_positions, corr_tick_labels)\n",
        "    plt.yticks(corr_tick_positions, corr_tick_labels)\n",
        "    plt.savefig(\"plots/toy_setup/toy_model_correlation_matrix.pdf\")\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8aef584d",
      "metadata": {},
      "source": [
        "## Finding the True L0\n",
        "\n",
        "Next, we calculate the true L0 of this dataset, i.e. the average number of features active per sample (spoiler: it's ~11)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2d2f1ea7",
      "metadata": {},
      "outputs": [],
      "source": [
        "sample = generate_batch(100_000)\n",
        "# True L0 = mean number of active (non-zero) features per sample, as a plain Python float\n",
        "true_l0 = (sample > 0).float().sum(dim=-1).mean().item()\n",
        "print(f\"True L0: {true_l0:.2f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f14d046f",
      "metadata": {},
      "source": [
        "## Can we detect when we're at the correct L0?\n",
        "\n",
        "Let's train JumpReLU SAEs with a range of L0 coefficients, giving SAE L0s that range from below the correct L0 to above it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ca71be7a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig\n",
        "from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae\n",
        "from sparse_but_wrong.toy_models.eval_sae import eval_sae\n",
        "\n",
        "# Cache directory for trained SAEs; one subdirectory per sparsity coefficient.\n",
        "SAE_CACHE_DIR = \"jr_saes_by_l1\"\n",
        "Path(SAE_CACHE_DIR).mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# Maps measured SAE L0 -> SAE. Note: `l1` below is passed as the JumpReLU `l0_coefficient`.\n",
        "saes_by_k = {}\n",
        "for l1 in [0.05, 0.07, 0.09, 0.1, 0.3, 0.5, 0.75, 1.0, 1.2, 1.3, 1.4, 1.5]:\n",
        "    sae_dir = f\"{SAE_CACHE_DIR}/{l1}\"\n",
        "    if Path(sae_dir).exists():\n",
        "        print(f\"Loading SAE with l1={l1} from disk\")\n",
        "        sae = JumpReLUTrainingSAE.load_from_disk(sae_dir)\n",
        "    else:\n",
        "        print(f\"Training SAE with l1={l1}\")\n",
        "        cfg = JumpReLUTrainingSAEConfig(\n",
        "            l0_coefficient=l1,\n",
        "            jumprelu_bandwidth=2.0,\n",
        "            jumprelu_init_threshold=0.1,\n",
        "            jumprelu_sparsity_loss_mode=\"tanh\",\n",
        "            l0_warm_up_steps=10_000,\n",
        "            normalize_activations=\"expected_average_only_in\",\n",
        "            d_in=toy_model.embed.weight.shape[0],\n",
        "            d_sae=len(feat_probs),  # one latent per toy-model feature\n",
        "        )\n",
        "        sae = JumpReLUTrainingSAE(cfg)\n",
        "        train_toy_sae(sae, toy_model, generate_batch)\n",
        "        sae.save_model(sae_dir)\n",
        "\n",
        "    stats = eval_sae(sae, toy_model, generate_batch)\n",
        "    print(f\"SAE L0: {stats.sae_l0}, dead latents: {stats.dead_features}, L0 coefficient: {l1:.2f}\")\n",
        "    if stats.sae_l0 in saes_by_k:\n",
        "        # Keys are measured L0s, so two coefficients landing on the same L0 would otherwise collide silently\n",
        "        print(f\"Warning: SAE L0 {stats.sae_l0} seen twice; keeping the SAE with l1={l1}\")\n",
        "    saes_by_k[stats.sae_l0] = sae\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d3255bba",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "from typing import Callable\n",
        "import torch\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from tqdm import tqdm\n",
        "from sae_lens import TrainingSAE\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "\n",
        "def calc_l0_dec_thresholds(\n",
        "    sae: TrainingSAE,\n",
        "    model,\n",
        "    generate_batch: Callable[[int], torch.Tensor],\n",
        "    n_samples: int = 100_000,\n",
        ") -> list[float]:\n",
        "    \"\"\"Compute the n-th decoder projection s_n^dec for every n in [0, d_sae).\n",
        "\n",
        "    Embeds `n_samples` (N) generated samples with `model.embed`, subtracts the SAE decoder\n",
        "    bias, and projects onto every decoder row. All N * d_sae projections are pooled and\n",
        "    sorted in descending order; entry n of the result is the value at sorted index n * N,\n",
        "    so n * N pooled projections are at least as large (about n per input on average).\n",
        "\n",
        "    Args:\n",
        "        sae: SAE providing `W_dec` and `b_dec` (assumes `W_dec` is (d_sae, d_in), as in sae_lens).\n",
        "        model: toy model whose `embed` maps generated feature activations to SAE inputs.\n",
        "        generate_batch: callable returning a batch with the requested number of samples.\n",
        "        n_samples: number of samples N to draw.\n",
        "\n",
        "    Returns:\n",
        "        A list of d_sae floats where entry n is s_n^dec.\n",
        "    \"\"\"\n",
        "    # Evaluation only: avoid building an autograd graph over the N x d_sae projections.\n",
        "    with torch.no_grad():\n",
        "        inputs = model.embed(generate_batch(n_samples))\n",
        "        hidden_pre_dec = (inputs - sae.b_dec) @ sae.W_dec.T\n",
        "        num_inputs, d_sae = hidden_pre_dec.shape[0], hidden_pre_dec.shape[-1]\n",
        "        sorted_hidden_pre_dec = hidden_pre_dec.flatten().sort(descending=True).values\n",
        "        k_inds = torch.arange(d_sae, device=sorted_hidden_pre_dec.device) * num_inputs\n",
        "        return sorted_hidden_pre_dec[k_inds].tolist()\n",
        "\n",
        "thresholds = {}\n",
        "for k, sae in tqdm(saes_by_k.items()):\n",
        "    thresholds[k] = calc_l0_dec_thresholds(sae, toy_model, generate_batch)\n",
        "\n",
        "# One row per SAE: its L0 plus one column l0_dec_<n> per decoder projection s_n^dec\n",
        "dp_df = pd.DataFrame([\n",
        "    {\"l0\": l0, **{f\"l0_dec_{i}\": threshold for i, threshold in enumerate(l0_thresholds)}}\n",
        "    for l0, l0_thresholds in thresholds.items()\n",
        "])\n",
        "\n",
        "# Mark the measured true L0 instead of a hard-coded value (float() also accepts a 0-d tensor)\n",
        "true_l0_x = float(true_l0)\n",
        "\n",
        "Path(\"plots/dpn_jumprelu\").mkdir(parents=True, exist_ok=True)\n",
        "# Global styling is loop-invariant, so apply it once\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "sns.set_theme()\n",
        "for k in range(12, 26):\n",
        "    with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "        plt.figure(figsize=(3, 1.5))\n",
        "        sns.lineplot(data=dp_df, x=\"l0\", y=f\"l0_dec_{k}\")\n",
        "        plt.axvline(x=true_l0_x, color='red', linestyle='--', linewidth=0.5, alpha=0.7, label='True L0')\n",
        "        plt.title(f\"$n={k}$ Decoder Projection vs SAE L0\")\n",
        "        plt.xlabel(\"SAE L0\")\n",
        "        plt.ylabel(\"$s_n^{dec}$\")\n",
        "        plt.grid(True, alpha=0.3)\n",
        "        plt.tight_layout()\n",
        "        plt.savefig(f\"plots/dpn_jumprelu/dp_{k}.pdf\")\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f91c411a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "# One row per SAE: its measured L0 and the L0 coefficient it was trained with\n",
        "coeff_df = pd.DataFrame(\n",
        "    [{\"l0\": k, \"l0_coeff\": sae.cfg.l0_coefficient} for k, sae in saes_by_k.items()]\n",
        ")\n",
        "\n",
        "# No earlier cell creates this output directory, so make sure it exists before saving\n",
        "Path(\"plots/toy_l0_jumprelu\").mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "plt.rcParams.update({\"figure.dpi\": 150})\n",
        "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
        "    plt.figure(figsize=(3, 2))\n",
        "    sns.scatterplot(data=coeff_df, x=\"l0\", y=\"l0_coeff\")\n",
        "\n",
        "    # Vertical line at the measured true L0 (float() also accepts a 0-d tensor)\n",
        "    plt.axvline(x=float(true_l0), color='red', linestyle='--', linewidth=0.5, alpha=0.7, label='True L0')\n",
        "    plt.xlabel(\"SAE L0\")\n",
        "    plt.ylabel(\"L0 Coefficient\")\n",
        "    plt.title(\"SAE L0 vs L0 Coefficient for JumpReLU SAEs\")\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(\"plots/toy_l0_jumprelu/l0_vs_l0_coeff_jumprelu.pdf\")\n",
        "    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
