{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ddb4c5e1",
   "metadata": {},
   "source": [
    "# Intuition plots\n",
    "\n",
    "This notebook generates the plots for the paper's intuition figure and saves them as PDFs in the `plots/` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcba5d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
    "\n",
    "def create_gaussian(x, mu, sigma):\n",
    "    \"\"\"Evaluate the normal density with mean `mu` and standard deviation `sigma` at `x`.\"\"\"\n",
    "    squared_offset = (x - mu) ** 2\n",
    "    exponent = -squared_offset / (2 * sigma ** 2)\n",
    "    normalizer = sigma * np.sqrt(2 * np.pi)\n",
    "    return np.exp(exponent) / normalizer\n",
    "\n",
    "\n",
    "# Shared x-grid on which both idealized curves are evaluated.\n",
    "x_range = np.linspace(-10, 30, 1000)\n",
    "# Each curve sums two Gaussians: one centred at 0 (weighted by 2) and one centred at 15.\n",
    "# The \"incorrect\" curve uses wider components (sigma 2 and 7 instead of 1 and 5).\n",
    "correct_model_vals = 2 * create_gaussian(x_range, 0, 1) + create_gaussian(x_range, 15, 5)\n",
    "incorrect_model_vals = 2 * create_gaussian(x_range, 0, 2) + create_gaussian(x_range, 15, 7)\n",
    "\n",
    "# Create the output directory for the saved figure if it does not exist yet.\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "\n",
    "# NOTE: this updates the global rcParams, so the DPI setting persists after this cell.\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
    "    plt.figure(figsize=(5, 2))\n",
    "    \n",
    "    # Helper: x-threshold that cuts off the rightmost fraction of the area under a curve\n",
    "    def find_area_threshold(x_vals, y_vals, target_fraction=0.4):\n",
    "        \"\"\"Return the x-value to the right of which `target_fraction` of the area lies.\n",
    "\n",
    "        The area under the sampled curve is measured with the trapezoidal rule.\n",
    "        The result is the largest sample `x_vals[i]` (with `i >= 1`) whose right-hand\n",
    "        tail area is at least `target_fraction` of the total area. If no such\n",
    "        sample exists, `x_vals[0]` is returned.\n",
    "\n",
    "        Args:\n",
    "            x_vals: 1-D sequence of sample positions.\n",
    "            y_vals: 1-D sequence of curve heights, same length as `x_vals`.\n",
    "            target_fraction: Fraction of the total area to keep on the right.\n",
    "\n",
    "        Returns:\n",
    "            The threshold, taken from `x_vals`.\n",
    "        \"\"\"\n",
    "        x_vals = np.asarray(x_vals, dtype=float)\n",
    "        y_vals = np.asarray(y_vals, dtype=float)\n",
    "\n",
    "        # Trapezoid area between each pair of neighbouring samples.\n",
    "        segment_areas = 0.5 * (y_vals[1:] + y_vals[:-1]) * np.diff(x_vals)\n",
    "        # tail_areas[i] is the area from x_vals[i] to the right edge. A single\n",
    "        # reversed cumulative sum yields every tail area in one O(n) pass.\n",
    "        tail_areas = np.concatenate((np.cumsum(segment_areas[::-1])[::-1], [0.0]))\n",
    "        target_area = tail_areas[0] * target_fraction\n",
    "\n",
    "        # Index 0 is left out of the search because x_vals[0] is the fallback.\n",
    "        reached = np.flatnonzero(tail_areas[1:] >= target_area)\n",
    "        if reached.size:\n",
    "            return x_vals[reached[-1] + 1]\n",
    "        return x_vals[0]\n",
    "    \n",
    "    # Thresholds that cut off the rightmost 40% of the area under each curve\n",
    "    correct_threshold = find_area_threshold(x_range, correct_model_vals)\n",
    "    incorrect_threshold = find_area_threshold(x_range, incorrect_model_vals)\n",
    "\n",
    "    # Draw the two curve outlines\n",
    "    line1 = plt.plot(x_range, correct_model_vals, label=\"Disentangled SAE\", linewidth=0.5)\n",
    "    line2 = plt.plot(x_range, incorrect_model_vals, label=\"SAE mixing correlated features\", linewidth=0.5)\n",
    "\n",
    "    # Reuse each outline's colour for the matching fills and annotation\n",
    "    color1, color2 = line1[0].get_color(), line2[0].get_color()\n",
    "\n",
    "    # Light shading under the whole of each curve\n",
    "    for curve_vals, curve_color in ((correct_model_vals, color1), (incorrect_model_vals, color2)):\n",
    "        plt.fill_between(x_range, curve_vals, alpha=0.3, color=curve_color)\n",
    "\n",
    "    # Darker shading over the rightmost 40% of each curve's area\n",
    "    for curve_vals, curve_threshold, curve_color in (\n",
    "        (correct_model_vals, correct_threshold, color1),\n",
    "        (incorrect_model_vals, incorrect_threshold, color2),\n",
    "    ):\n",
    "        plt.fill_between(x_range, curve_vals,\n",
    "                         where=(x_range >= curve_threshold),\n",
    "                         alpha=0.6, color=curve_color, interpolate=True)\n",
    "\n",
    "    # Curve heights at the grid points closest to the thresholds\n",
    "    correct_idx = np.abs(x_range - correct_threshold).argmin()\n",
    "    incorrect_idx = np.abs(x_range - incorrect_threshold).argmin()\n",
    "    correct_y = correct_model_vals[correct_idx]\n",
    "    incorrect_y = incorrect_model_vals[incorrect_idx]\n",
    "\n",
    "    # Labelled arrows pointing at each threshold on its curve\n",
    "    for anchor_x, anchor_y, curve_color in (\n",
    "        (correct_threshold, correct_y, color1),\n",
    "        (incorrect_threshold, incorrect_y, color2),\n",
    "    ):\n",
    "        plt.annotate('$s_n^{dec}$',\n",
    "                     xy=(anchor_x, anchor_y),\n",
    "                     xytext=(anchor_x + 4, anchor_y + 0.15),\n",
    "                     arrowprops={'arrowstyle': '->', 'color': curve_color, 'lw': 0.5},\n",
    "                     fontsize=10, ha='center', color=curve_color)\n",
    "\n",
    "    # plt.xlabel(\"Decoder projection on input activations\")\n",
    "    plt.title(\"Idealized decoder projection histogram\")\n",
    "\n",
    "    # Legend: thicker sample lines than the plotted curves, thin light-gray frame\n",
    "    legend = plt.legend()\n",
    "    for handle in legend.get_lines():\n",
    "        handle.set_linewidth(1)\n",
    "    legend_frame = legend.get_frame()\n",
    "    legend_frame.set_edgecolor('lightgray')\n",
    "    legend_frame.set_linewidth(0.5)\n",
    "\n",
    "    # Axis styling (applied inside the rc_context block, before the figure is saved)\n",
    "    ax = plt.gca()\n",
    "    ax.grid(False)\n",
    "    ax.set_xticks([0])\n",
    "    ax.set_xticklabels(['0'])\n",
    "    ax.set_yticks([])\n",
    "    for side in ('top', 'right', 'bottom', 'left'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "    ax.set_facecolor('white')\n",
    "    ax.tick_params(axis='x', which='major', length=4, width=0.5, color='black')\n",
    "    ax.tick_params(axis='y', which='both', left=False, labelleft=False)\n",
    "    ax.set_ylim(bottom=0)\n",
    "\n",
    "    plt.savefig(\"plots/intuition_plot.pdf\")\n",
    "\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c019a87f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT\n",
    "\n",
    "def create_gaussian(x, mu, sigma):\n",
    "    \"\"\"Evaluate the normal density with mean `mu` and standard deviation `sigma` at `x`.\"\"\"\n",
    "    squared_offset = (x - mu) ** 2\n",
    "    exponent = -squared_offset / (2 * sigma ** 2)\n",
    "    normalizer = sigma * np.sqrt(2 * np.pi)\n",
    "    return np.exp(exponent) / normalizer\n",
    "\n",
    "\n",
    "# Shared x-grid on which both idealized curves are evaluated.\n",
    "x_range = np.linspace(-10, 30, 1000)\n",
    "# Each curve sums two Gaussians: one centred at 0 (weighted by 2) and one centred at 15.\n",
    "# The \"incorrect\" curve uses wider components (sigma 2 and 7 instead of 1 and 5).\n",
    "correct_model_vals = 2 * create_gaussian(x_range, 0, 1) + create_gaussian(x_range, 15, 5)\n",
    "incorrect_model_vals = 2 * create_gaussian(x_range, 0, 2) + create_gaussian(x_range, 15, 7)\n",
    "\n",
    "# Create the output directory for the saved figure if it does not exist yet.\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "\n",
    "# NOTE: this updates the global rcParams, so the DPI setting persists after this cell.\n",
    "plt.rcParams.update({\"figure.dpi\": 150})\n",
    "with plt.rc_context(SEABORN_RC_CONTEXT):\n",
    "    plt.figure(figsize=(5, 2))\n",
    "    \n",
    "    # Helper: x-threshold that cuts off the rightmost fraction of the area under a curve\n",
    "    def find_area_threshold(x_vals, y_vals, target_fraction=0.4):\n",
    "        \"\"\"Return the x-value to the right of which `target_fraction` of the area lies.\n",
    "\n",
    "        The area under the sampled curve is measured with the trapezoidal rule.\n",
    "        The result is the largest sample `x_vals[i]` (with `i >= 1`) whose right-hand\n",
    "        tail area is at least `target_fraction` of the total area. If no such\n",
    "        sample exists, `x_vals[0]` is returned.\n",
    "\n",
    "        Args:\n",
    "            x_vals: 1-D sequence of sample positions.\n",
    "            y_vals: 1-D sequence of curve heights, same length as `x_vals`.\n",
    "            target_fraction: Fraction of the total area to keep on the right.\n",
    "\n",
    "        Returns:\n",
    "            The threshold, taken from `x_vals`.\n",
    "        \"\"\"\n",
    "        x_vals = np.asarray(x_vals, dtype=float)\n",
    "        y_vals = np.asarray(y_vals, dtype=float)\n",
    "\n",
    "        # Trapezoid area between each pair of neighbouring samples.\n",
    "        segment_areas = 0.5 * (y_vals[1:] + y_vals[:-1]) * np.diff(x_vals)\n",
    "        # tail_areas[i] is the area from x_vals[i] to the right edge. A single\n",
    "        # reversed cumulative sum yields every tail area in one O(n) pass.\n",
    "        tail_areas = np.concatenate((np.cumsum(segment_areas[::-1])[::-1], [0.0]))\n",
    "        target_area = tail_areas[0] * target_fraction\n",
    "\n",
    "        # Index 0 is left out of the search because x_vals[0] is the fallback.\n",
    "        reached = np.flatnonzero(tail_areas[1:] >= target_area)\n",
    "        if reached.size:\n",
    "            return x_vals[reached[-1] + 1]\n",
    "        return x_vals[0]\n",
    "    \n",
    "    # This plain variant shades only the full area under each curve, so no\n",
    "    # area thresholds are computed here.\n",
    "\n",
    "    # Draw the two curve outlines\n",
    "    line1 = plt.plot(x_range, correct_model_vals, label=\"Disentangled SAE\", linewidth=0.5)\n",
    "    line2 = plt.plot(x_range, incorrect_model_vals, label=\"SAE mixing correlated features\", linewidth=0.5)\n",
    "\n",
    "    # Reuse each outline's colour for the matching fill\n",
    "    color1 = line1[0].get_color()\n",
    "    color2 = line2[0].get_color()\n",
    "\n",
    "    # Light shading under the whole of each curve\n",
    "    plt.fill_between(x_range, correct_model_vals, alpha=0.3, color=color1)\n",
    "    plt.fill_between(x_range, incorrect_model_vals, alpha=0.3, color=color2)\n",
    "\n",
    "    # plt.xlabel(\"Decoder projection on input activations\")\n",
    "    plt.title(\"Idealized decoder projection histogram\")\n",
    "\n",
    "    # Legend: thicker sample lines than the plotted curves, thin light-gray frame\n",
    "    legend = plt.legend()\n",
    "    for line in legend.get_lines():\n",
    "        line.set_linewidth(1)\n",
    "    legend.get_frame().set_edgecolor('lightgray')\n",
    "    legend.get_frame().set_linewidth(0.5)\n",
    "\n",
    "    # Axis styling (applied inside the rc_context block, before the figure is saved)\n",
    "    ax = plt.gca()\n",
    "    ax.grid(False)\n",
    "    ax.set_xticks([0])\n",
    "    ax.set_xticklabels(['0'])\n",
    "    ax.set_yticks([])\n",
    "    for side in ('top', 'right', 'bottom', 'left'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "    ax.set_facecolor('white')\n",
    "    ax.tick_params(axis='x', which='major', length=4, width=0.5, color='black')\n",
    "    ax.tick_params(axis='y', which='both', left=False, labelleft=False)\n",
    "    ax.set_ylim(bottom=0)\n",
    "\n",
    "    plt.savefig(\"plots/intuition_plot_plain.pdf\")\n",
    "\n",
    "    plt.show()\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
