{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a2e888",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faacf1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('final_plots/wandb_65k_by_layers.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8fa461e",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols_to_plot = [\n",
    "    'sweep_pair.num_heads',\n",
    "    'sweep_pair.num_mkeys',\n",
    "    'sweep_pair.num_nkeys',\n",
    "    'sweep_pair.dict_size',\n",
    "    'explained_variance',\n",
    "    'parameters_count',\n",
    "    'accum_num_flops',\n",
    "    'n_tokens'\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efffa462",
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping = {\n",
    "    'sweep id': (\"1.5B\", 1000 * 1e6),\n",
    "    'sweep id': (\"1.5B\", 500 * 1e6),\n",
    "    'sweep id': (\"1.5B\", 100 * 1e6),\n",
    "    'sweep id': (\"3B\", 1000 * 1e6),\n",
    "    'sweep id': (\"3B\", 500 * 1e6),\n",
    "    'sweep id': (\"3B\", 100 * 1e6),\n",
    "    'sweep id': (\"7B\", 500 * 1e6),\n",
    "    'sweep id': (\"7B\", 100 * 1e6)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef4a3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "qwen_data = df[df['model_name'] == 'Qwen/Qwen2.5-1.5B']\n",
    "gemma_data = df[df['model_name'] == 'google/gemma-2-2b']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62daf052",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "for i, row in df.iterrows():\n",
    "    entry = {\n",
    "        'Name': row['Name'],\n",
    "        'H': row['sweep_pair.num_heads'],\n",
    "        'M': row['sweep_pair.num_mkeys'],\n",
    "        'N': row['sweep_pair.num_nkeys'],\n",
    "        'F': row['sweep_pair.dict_size'],\n",
    "        'Explained Variance': row['explained_variance'],\n",
    "        'Parameters': row['parameters_count'],\n",
    "        'FLOPS': row['accum_num_flops'],\n",
    "        'Tokens': row['n_tokens'],\n",
    "        'Layer': row['layer']\n",
    "    }\n",
    "\n",
    "    if 'kronsae' in row['Name']:\n",
    "        entry['SAE'] = 'KronSAE'\n",
    "    else:\n",
    "        entry['SAE'] = 'TopK'\n",
    "    \n",
    "    data.append(entry)\n",
    "\n",
    "data = pd.DataFrame(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ea00ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(5, 4), dpi=250)\n",
    "ax.grid(alpha=0.35)\n",
    "sns.lineplot(data, x='Layer', y='Explained Variance', hue='SAE', ax=ax,palette='tab10', marker='o',)\n",
    "\n",
    "ax.set_xlabel('Layer', fontsize=12)\n",
    "ax.set_ylabel('Explained Variance', fontsize=12)\n",
    "plt.legend()\n",
    "ax.set_xticks(data['Layer'].unique())\n",
    "ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
    "plt.tight_layout()\n",
    "plt.savefig('layer vs explained variance.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e11119aa-3283-44c4-8f65-1c95f87395c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('final_plots/ev_l0.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5022e906",
   "metadata": {},
   "outputs": [],
   "source": [
    "kronsae = []\n",
    "topk = []\n",
    "for i, row in df.iterrows():\n",
    "    entry = {\n",
    "        'Name': row['Name'],\n",
    "        'H': row['sweep_pair.num_heads'],\n",
    "        'M': row['sweep_pair.num_mkeys'],\n",
    "        'N': row['sweep_pair.num_nkeys'],\n",
    "        'F': row['sweep_pair.dict_size'],\n",
    "        'Explained Variance': row['explained_variance'],\n",
    "        'Parameters': row['parameters_count'],\n",
    "        'FLOPS': row['accum_num_flops'],\n",
    "        'Tokens': row['n_tokens'],\n",
    "        'Model': row['model_name'],\n",
    "        'L0': row['topk2']\n",
    "    }\n",
    "    \n",
    "    if 'kronsae' in row['Name']:\n",
    "        kronsae.append(entry)\n",
    "    else:\n",
    "        topk.append(entry)\n",
    "\n",
    "kronsae = pd.DataFrame(kronsae)\n",
    "topk = pd.DataFrame(topk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "425fbd12-3990-4b41-9d74-1123f7601c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "topk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "844c4785",
   "metadata": {},
   "outputs": [],
   "source": [
    "qwen_kronsae_selected = kronsae[(kronsae[\"Model\"] == \"Qwen/Qwen2.5-1.5B\") & (kronsae[\"M\"] == 2) & (kronsae[\"N\"] == 8)].sort_values(\"L0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee596bd7-04ae-4860-a2e9-4532291cbf61",
   "metadata": {},
   "outputs": [],
   "source": [
    "google_kronsae_selected = kronsae[(kronsae[\"Model\"] == \"google/gemma-2-2b\") & (kronsae[\"M\"] == 2) & (kronsae[\"N\"] == 8)].sort_values(\"Explained Variance\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d54051f6-59e7-4d5a-94d9-acf5cb880451",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged = pd.concat([qwen_kronsae_selected, google_kronsae_selected])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b38d6b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from matplotlib.lines import Line2D \n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 4), dpi=250)\n",
    "cplt = sns.color_palette('tab10')\n",
    "\n",
    "sns.lineplot(\n",
    "    data=topk[(topk[\"Model\"] == \"google/gemma-2-2b\")],\n",
    "    x='L0',\n",
    "    y='Explained Variance',\n",
    "    marker='o',\n",
    "    lw=1.2,\n",
    "    color=cplt[0],\n",
    "    ls='--',\n",
    "    markersize=10,\n",
    ")\n",
    "\n",
    "sns.lineplot(\n",
    "    data=merged[(merged[\"Model\"] == \"google/gemma-2-2b\")],\n",
    "    x='L0',\n",
    "    y='Explained Variance',\n",
    "    markers=True,\n",
    "    marker='^',\n",
    "    lw=1.2,\n",
    "    color=cplt[1],\n",
    "    ls='--',\n",
    "    markersize=10,\n",
    ")\n",
    "\n",
    "sns.lineplot(\n",
    "    data=topk[(topk[\"Model\"] == \"Qwen/Qwen2.5-1.5B\")],\n",
    "    x='L0', \n",
    "    y='Explained Variance',\n",
    "    marker='o',\n",
    "    lw=1.2,\n",
    "    color=cplt[0],\n",
    "    ls='-',\n",
    "    markersize=10,\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "sns.lineplot(\n",
    "    data=merged[(merged[\"Model\"] == \"Qwen/Qwen2.5-1.5B\")],\n",
    "    x='L0',\n",
    "    y='Explained Variance',\n",
    "    markers=True,\n",
    "    marker='^',\n",
    "    lw=1.2,\n",
    "    color=cplt[1],\n",
    "    ls='-',\n",
    "    markersize=10,\n",
    ")\n",
    "\n",
    "ax.set_xscale('log', base=2)\n",
    "ax.set_xticks(topk['L0'].unique())\n",
    "ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
    "\n",
    "ax.set_xlabel(r'$\\ell_0$ Sparsity', fontsize=12)\n",
    "ax.set_ylabel('Explained Variance', fontsize=12)\n",
    "ax.grid(True, alpha=0.35, which='both')\n",
    "\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "star_marker = Line2D(\n",
    "    [0], [0], color= cplt[0], marker='o', linestyle=\"\", markersize=4, label='TopK'\n",
    ")\n",
    "kronsae_marker = Line2D(\n",
    "    [0], [0], color= cplt[1], marker='^', linestyle=\"\", markersize=4, label='TopK'\n",
    ")\n",
    "qwen_marker = Line2D(\n",
    "    [0], [0], color='gray', marker='', linestyle=\"-\", markersize=4, label='TopK'\n",
    ")\n",
    "\n",
    "gemma_marker = Line2D(\n",
    "    [0], [0], color='gray', marker='', linestyle=\"--\", markersize=4, label='TopK'\n",
    ")\n",
    "\n",
    "custom_handles =  [star_marker, kronsae_marker]  + [qwen_marker, gemma_marker]\n",
    "custom_labels = ['TopK', 'KronSAE', 'Qwen2.5 1.5B', 'Gemma 2 2B', ]\n",
    "\n",
    "ax.legend(\n",
    "    handles=custom_handles,\n",
    "    labels=custom_labels,\n",
    "    loc='upper left',\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'l0 vs explained variance.pdf')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
