{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Balancing Hedging and Absorption in Matryoshka SAEs\n",
    "\n",
    "Hedging and absorption have opposite effects on the encoder of an SAE. Absorption causes the latent tracking the parent feature to merge in negative components of child features. Hedging does the opposite: it merges positive components of child features into the parent latent.\n",
    "\n",
    "As a result, it's actually beneficial to allow the outer layers of a Matryoshka SAE to exert some absorption pressure on the inner layers to reduce the amount of hedging and get closer to learning the true features.\n",
    "\n",
    "In this notebook, we explore this idea in a toy model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "## Create toy model and hierarchical feature generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hedging_paper.toy_models.toy_model import ToyModel\n",
    "from hedging_paper.toy_models.tree_feature_generator import TreeFeatureGenerator\n",
    "\n",
    "# Sizes shared by the experiments in this notebook.\n",
    "DEFAULT_D_IN = 50\n",
    "DEFAULT_D_SAE = 4\n",
    "DEFAULT_NUM_FEATS = 4\n",
    "\n",
    "toy_model = ToyModel(num_feats=DEFAULT_NUM_FEATS, hidden_dim=DEFAULT_D_IN)\n",
    "\n",
    "# One parent (0.25) with four non-mutually-exclusive children: three at 0.15\n",
    "# and one at 0.55 that is created with is_read_out=False.\n",
    "children1 = [\n",
    "    TreeFeatureGenerator(0.15),\n",
    "    TreeFeatureGenerator(0.15),\n",
    "    TreeFeatureGenerator(0.15),\n",
    "    TreeFeatureGenerator(0.55, is_read_out=False),\n",
    "]\n",
    "parent1 = TreeFeatureGenerator(0.25, children1, mutually_exclusive_children=False)\n",
    "\n",
    "# Root node (1.0, is_read_out=False) whose only child is parent1.\n",
    "feat_generator = TreeFeatureGenerator(1.0, [parent1], is_read_out=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Balancing the inner and outer Matryoshka layers\n",
    "\n",
    "Can we find a loss multiplier that will work to balance these layers?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig\n",
    "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
    "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
    "\n",
    "# Runner config for the SAE that the next cell plots as \"Detached Matryoshka SAE\".\n",
    "# Left disabled here: matryoshka_inner_loss_multipliers=[10.0],\n",
    "# reconstruction_loss=\"L2\", matryoshka_reconstruction_loss=\"L2\".\n",
    "cfg = MatryoshkaSAERunnerConfig(\n",
    "    # sizes\n",
    "    d_in=toy_model.embed.weight.shape[0],\n",
    "    d_sae=DEFAULT_D_SAE,\n",
    "    # Matryoshka settings\n",
    "    matryoshka_steps=[1, DEFAULT_D_SAE],\n",
    "    skip_outer_loss=True,\n",
    "    use_delta_loss=True,\n",
    "    # sparsity penalty\n",
    "    l1_coefficient=3e-2,\n",
    "    scale_sparsity_penalty_by_decoder_norm=True,\n",
    "    # decoder / initialisation\n",
    "    normalize_sae_decoder=False,\n",
    "    init_encoder_as_decoder_transpose=True,\n",
    "    apply_b_dec_to_input=True,\n",
    "    b_dec_init_method=\"zeros\",\n",
    ")\n",
    "sae_full_mat = MatryoshkaSAE(\n",
    "    MatryoshkaSAEConfig.from_sae_runner_config(cfg),\n",
    ")\n",
    "\n",
    "# Initialise from the toy model with zero noise and feature ordering [0, 1, 2, 3].\n",
    "init_sae_to_match_model(\n",
    "    sae_full_mat,\n",
    "    toy_model,\n",
    "    noise_level=0.0,\n",
    "    feature_ordering=torch.tensor([0, 1, 2, 3]),\n",
    ")\n",
    "\n",
    "# NOTE: training occasionally gets stuck in a poor local minimum. If that happens,\n",
    "# rerun this cell and it should converge properly.\n",
    "train_toy_sae(sae_full_mat, toy_model, feat_generator.sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
    "\n",
    "# Seaborn cosine-similarity plot for sae_full_mat (save_path: plots/toy_sae_mat_detached.pdf).\n",
    "plot_sae_feat_cos_sims_seaborn(\n",
    "    sae_full_mat,\n",
    "    toy_model,\n",
    "    title=r\"Detached Matryoshka SAE ($\\beta = \\infty$)\",\n",
    "    save_path=\"plots/toy_sae_mat_detached.pdf\",\n",
    "    width=3,\n",
    "    height=1.5,\n",
    "    show_values=False,\n",
    "    one_based_indexing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig\n",
    "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
    "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
    "\n",
    "# Runner config for the SAE that the next cell plots as \"Standard SAE\":\n",
    "# a single Matryoshka step covering all DEFAULT_D_SAE latents.\n",
    "# Left disabled here: matryoshka_inner_loss_multipliers=[10.0], use_delta_loss=True,\n",
    "# reconstruction_loss=\"L2\", matryoshka_reconstruction_loss=\"L2\".\n",
    "cfg = MatryoshkaSAERunnerConfig(\n",
    "    # sizes\n",
    "    d_in=toy_model.embed.weight.shape[0],\n",
    "    d_sae=DEFAULT_D_SAE,\n",
    "    # Matryoshka settings\n",
    "    matryoshka_steps=[DEFAULT_D_SAE],\n",
    "    skip_outer_loss=True,\n",
    "    # sparsity penalty\n",
    "    l1_coefficient=3e-2,\n",
    "    scale_sparsity_penalty_by_decoder_norm=True,\n",
    "    # decoder / initialisation\n",
    "    normalize_sae_decoder=False,\n",
    "    init_encoder_as_decoder_transpose=True,\n",
    "    apply_b_dec_to_input=True,\n",
    "    b_dec_init_method=\"zeros\",\n",
    ")\n",
    "sae_full_abs = MatryoshkaSAE(\n",
    "    MatryoshkaSAEConfig.from_sae_runner_config(cfg),\n",
    ")\n",
    "\n",
    "# Initialise from the toy model with zero noise and feature ordering [0, 1, 2, 3].\n",
    "init_sae_to_match_model(\n",
    "    sae_full_abs,\n",
    "    toy_model,\n",
    "    noise_level=0.0,\n",
    "    feature_ordering=torch.tensor([0, 1, 2, 3]),\n",
    ")\n",
    "\n",
    "train_toy_sae(sae_full_abs, toy_model, feat_generator.sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn\n",
    "\n",
    "# Seaborn cosine-similarity plot for sae_full_abs (save_path: plots/toy_sae_abs.pdf).\n",
    "plot_sae_feat_cos_sims_seaborn(\n",
    "    sae_full_abs,\n",
    "    toy_model,\n",
    "    title=r\"Standard SAE ($\\beta = 0$)\",\n",
    "    save_path=\"plots/toy_sae_abs.pdf\",\n",
    "    width=3,\n",
    "    height=1.5,\n",
    "    show_values=False,\n",
    "    one_based_indexing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig\n",
    "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
    "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
    "\n",
    "# Runner config for the SAE that the next cell plots as \"Balanced Matryoshka SAE\":\n",
    "# steps [1, DEFAULT_D_SAE] with inner loss multipliers [0.25, 1.0].\n",
    "# Left disabled here: use_delta_loss=True,\n",
    "# reconstruction_loss=\"L2\", matryoshka_reconstruction_loss=\"L2\".\n",
    "cfg = MatryoshkaSAERunnerConfig(\n",
    "    # sizes\n",
    "    d_in=toy_model.embed.weight.shape[0],\n",
    "    d_sae=DEFAULT_D_SAE,\n",
    "    # Matryoshka settings\n",
    "    matryoshka_steps=[1, DEFAULT_D_SAE],\n",
    "    matryoshka_inner_loss_multipliers=[0.25, 1.0],\n",
    "    skip_outer_loss=True,\n",
    "    # sparsity penalty\n",
    "    l1_coefficient=3e-2,\n",
    "    scale_sparsity_penalty_by_decoder_norm=True,\n",
    "    # decoder / initialisation\n",
    "    normalize_sae_decoder=False,\n",
    "    init_encoder_as_decoder_transpose=True,\n",
    "    apply_b_dec_to_input=True,\n",
    "    b_dec_init_method=\"zeros\",\n",
    ")\n",
    "sae_bal_mat = MatryoshkaSAE(\n",
    "    MatryoshkaSAEConfig.from_sae_runner_config(cfg),\n",
    ")\n",
    "\n",
    "# Initialise from the toy model with zero noise and feature ordering [0, 1, 2, 3].\n",
    "init_sae_to_match_model(\n",
    "    sae_bal_mat,\n",
    "    toy_model,\n",
    "    noise_level=0.0,\n",
    "    feature_ordering=torch.tensor([0, 1, 2, 3]),\n",
    ")\n",
    "\n",
    "train_toy_sae(sae_bal_mat, toy_model, feat_generator.sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hedging_paper.toy_models.plotting import (\n",
    "    plot_sae_feat_cos_sims,\n",
    "    plot_sae_feat_cos_sims_seaborn,\n",
    ")\n",
    "\n",
    "# Two views of sae_bal_mat: the plain plot first, then the seaborn figure\n",
    "# (save_path: plots/toy_mat_balanced.pdf).\n",
    "plot_sae_feat_cos_sims(\n",
    "    sae_bal_mat,\n",
    "    toy_model,\n",
    "    \"Balanced Matryoshka SAE\",\n",
    "    show_values=False,\n",
    ")\n",
    "plot_sae_feat_cos_sims_seaborn(\n",
    "    sae_bal_mat,\n",
    "    toy_model,\n",
    "    title=r\"Balanced Matryoshka SAE ($\\beta = 0.25$)\",\n",
    "    save_path=\"plots/toy_mat_balanced.pdf\",\n",
    "    width=3,\n",
    "    height=1.5,\n",
    "    show_values=False,\n",
    "    one_based_indexing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Is it always possible to balance hedging and absorption?\n",
    "\n",
    "In the above toy models, all child features fire with probability 0.15. What if they fire with different probabilities?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hedging_paper.toy_models.toy_model import ToyModel\n",
    "from hedging_paper.toy_models.tree_feature_generator import TreeFeatureGenerator\n",
    "\n",
    "# Same tree shape as before, but the first three children now use different\n",
    "# values (0.02, 0.15, 0.5); the fourth child (0.55, is_read_out=False) and the\n",
    "# parent (0.25) are unchanged.\n",
    "children1 = [\n",
    "    TreeFeatureGenerator(0.02),\n",
    "    TreeFeatureGenerator(0.15),\n",
    "    TreeFeatureGenerator(0.5),\n",
    "    TreeFeatureGenerator(0.55, is_read_out=False),\n",
    "]\n",
    "parent1 = TreeFeatureGenerator(0.25, children1, mutually_exclusive_children=False)\n",
    "\n",
    "# Root node (1.0, is_read_out=False) whose only child is parent1.\n",
    "feat_generator_unbalanced = TreeFeatureGenerator(1.0, [parent1], is_read_out=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig\n",
    "from hedging_paper.toy_models.initialization import init_sae_to_match_model\n",
    "from hedging_paper.toy_models.train_toy_sae import train_toy_sae\n",
    "\n",
    "# Runner config for the SAE that the next cell plots as \"Unbalanceable Matryoshka SAE\":\n",
    "# steps [1, DEFAULT_D_SAE] with inner loss multipliers [0.17, 1.0].\n",
    "cfg = MatryoshkaSAERunnerConfig(\n",
    "    # sizes\n",
    "    d_in=toy_model.embed.weight.shape[0],\n",
    "    d_sae=DEFAULT_D_SAE,\n",
    "    # Matryoshka settings\n",
    "    matryoshka_steps=[1, DEFAULT_D_SAE],\n",
    "    matryoshka_inner_loss_multipliers=[0.17, 1.0],\n",
    "    skip_outer_loss=True,\n",
    "    # sparsity penalty\n",
    "    l1_coefficient=3e-2,\n",
    "    scale_sparsity_penalty_by_decoder_norm=True,\n",
    "    # decoder / initialisation\n",
    "    normalize_sae_decoder=False,\n",
    "    init_encoder_as_decoder_transpose=True,\n",
    "    apply_b_dec_to_input=True,\n",
    "    b_dec_init_method=\"zeros\",\n",
    ")\n",
    "sae_unbal_mat = MatryoshkaSAE(\n",
    "    MatryoshkaSAEConfig.from_sae_runner_config(cfg),\n",
    ")\n",
    "\n",
    "# Initialise from the toy model with zero noise and feature ordering [0, 1, 2, 3].\n",
    "init_sae_to_match_model(\n",
    "    sae_unbal_mat,\n",
    "    toy_model,\n",
    "    noise_level=0.0,\n",
    "    feature_ordering=torch.tensor([0, 1, 2, 3]),\n",
    ")\n",
    "\n",
    "# Train on samples drawn from the unbalanced feature generator.\n",
    "train_toy_sae(sae_unbal_mat, toy_model, feat_generator_unbalanced.sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hedging_paper.toy_models.plotting import (\n",
    "    plot_sae_feat_cos_sims,\n",
    "    plot_sae_feat_cos_sims_seaborn,\n",
    ")\n",
    "\n",
    "# Two views of sae_unbal_mat: the plain plot first, then the seaborn figure\n",
    "# (save_path: plots/toy_mat_unbalanceable.pdf).\n",
    "plot_sae_feat_cos_sims(\n",
    "    sae_unbal_mat,\n",
    "    toy_model,\n",
    "    \"Unbalanceable Matryoshka SAE\",\n",
    "    show_values=False,\n",
    ")\n",
    "plot_sae_feat_cos_sims_seaborn(\n",
    "    sae_unbal_mat,\n",
    "    toy_model,\n",
    "    title=r\"Unbalanceable Matryoshka SAE ($\\beta = 0.17$)\",\n",
    "    save_path=\"plots/toy_mat_unbalanceable.pdf\",\n",
    "    width=3,\n",
    "    height=1.5,\n",
    "    show_values=False,\n",
    "    one_based_indexing=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae-rethink-3xX6iC0B-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
