{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a2e888",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50eb4814-d952-453e-a0f8-934dca96d0f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install -q scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faacf1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump of the W&B runs table; a later cell filters and iterates over its rows.\n",
    "df = pd.read_csv('wandb_v3.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8fa461e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw W&B column names: sweep hyper-parameters first, then metrics.\n",
    "# NOTE(review): not referenced by any other cell of this notebook.\n",
    "cols_to_plot = (\n",
    "    [f'sweep_pair.{key}' for key in ('num_heads', 'num_mkeys', 'num_nkeys', 'dict_size')]\n",
    "    + ['explained_variance', 'parameters_count', 'accum_num_flops', 'n_tokens']\n",
    "    + [f'performance/recovery_from_{ref}' for ref in ('zero', 'mean')]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efffa462",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# W&B sweep id -> (model size, reference token budget); only runs from these sweeps are used.\n",
    "# NOTE(review): every key is still the placeholder 'sweep id'. In a dict literal duplicate keys\n",
    "# silently collapse to the last one, so the pairs are listed explicitly and collisions reported.\n",
    "sweep_entries = [\n",
    "    ('sweep id', (\"1.5B\", 1000 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 1000 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 500 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 500 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 100 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 100 * 1e6)),\n",
    "\n",
    "    ('sweep id', (\"1.5B\", 100 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 500 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 1000 * 1e6)),\n",
    "\n",
    "    ('sweep id', (\"3B\", 1000 * 1e6)),\n",
    "    ('sweep id', (\"3B\", 500 * 1e6)),\n",
    "    ('sweep id', (\"3B\", 100 * 1e6)),\n",
    "\n",
    "    ('sweep id', (\"7B\", 500 * 1e6)),\n",
    "    ('sweep id', (\"7B\", 100 * 1e6)),\n",
    "    ('sweep id', (\"1.5B\", 1000 * 1e6)),  # NOTE(review): 1.5B entry inside the 7B group; confirm\n",
    "\n",
    "    ('sweep id', (\"1.4B\", 100 * 1e6)),\n",
    "    ('sweep id', (\"1.4B\", 500 * 1e6)),\n",
    "    ('sweep id', (\"1.4B\", 1000 * 1e6)),\n",
    "    ('sweep id', (\"1.4B\", 1000 * 1e6)),\n",
    "]\n",
    "mapping = dict(sweep_entries)\n",
    "if len(mapping) != len(sweep_entries):\n",
    "    warnings.warn(\n",
    "        f\"{len(sweep_entries) - len(mapping)} duplicate sweep id(s) in `mapping`; \"\n",
    "        \"only the last value listed for each id is kept\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3d6126",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect exp51 runs from the sweeps listed in `mapping` and split them into a\n",
    "# KronSAE table (run name contains 'kronsae') and a TopK table (everything else).\n",
    "column_sources = {\n",
    "    'H': 'sweep_pair.num_heads',\n",
    "    'M': 'sweep_pair.num_mkeys',\n",
    "    'N': 'sweep_pair.num_nkeys',\n",
    "    'F': 'sweep_pair.dict_size',\n",
    "    'Explained Variance': 'explained_variance',\n",
    "    'CE_from_zero': 'performance/recovery_from_zero',\n",
    "    'CE_from_mean': 'performance/recovery_from_mean',\n",
    "    'Parameters': 'parameters_count',\n",
    "    'FLOPS': 'accum_num_flops',\n",
    "    'Tokens': 'n_tokens',\n",
    "}\n",
    "kronsae_records, topk_records = [], []\n",
    "\n",
    "for _idx, run in df.iterrows():\n",
    "    if 'exp51' not in run['Name'] or run['Sweep'] not in mapping:\n",
    "        continue\n",
    "\n",
    "    size_label, budget = mapping[run['Sweep']]\n",
    "    record = {'Name': run['Name']}\n",
    "    for target, source in column_sources.items():\n",
    "        record[target] = run[source]\n",
    "    record['Reference Tokens'] = budget\n",
    "    record['Model size'] = size_label\n",
    "\n",
    "    if 'kronsae' not in run['Name']:\n",
    "        topk_records.append(record)\n",
    "        continue\n",
    "    # KronSAE only: 'Experts' = H * M.\n",
    "    record['Experts'] = record['H'] * record['M']\n",
    "    kronsae_records.append(record)\n",
    "\n",
    "kronsae = pd.DataFrame(kronsae_records)\n",
    "topk = pd.DataFrame(topk_records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd91fa66-4c3f-4e58-a454-c459b44a7230",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from matplotlib.lines import Line2D\n",
    "import numpy as np\n",
    "\n",
    "# Number of heads H vs explained variance for KronSAE runs (M <= N only),\n",
    "# one panel per reference token budget.\n",
    "sns.set_style(\"whitegrid\")\n",
    "fig, axs = plt.subplots(1, 3, figsize=(15, 5), dpi=250, sharey='row')\n",
    "\n",
    "token_budgets = sorted(kronsae['Reference Tokens'].unique())\n",
    "assert len(token_budgets) <= len(axs), (\n",
    "    f\"{len(token_budgets)} token budgets but only {len(axs)} panels\"\n",
    ")\n",
    "\n",
    "for i, n_tokens in enumerate(token_budgets):\n",
    "    plot_df = kronsae[kronsae['Reference Tokens'] == n_tokens]\n",
    "    plot_df = plot_df[plot_df[\"M\"] <= plot_df[\"N\"]]\n",
    "    ax = axs[i]\n",
    "\n",
    "    sns.lineplot(\n",
    "        data=plot_df,\n",
    "        x='H',  # number of heads (powers of 2)\n",
    "        y='Explained Variance',\n",
    "        hue='F',\n",
    "        style='M',\n",
    "        markers=True,\n",
    "        dashes=True,\n",
    "        palette='tab10',\n",
    "        markersize=11,\n",
    "        linewidth=1.5,\n",
    "        ax=ax,\n",
    "        legend=i == 0\n",
    "    )\n",
    "\n",
    "    # Log scale first: set_xscale installs its own tick formatter, so the\n",
    "    # ScalarFormatter has to be applied afterwards to take effect.\n",
    "    ax.set_xscale('log', base=2)\n",
    "    ax.set_xticks(plot_df['H'].unique())\n",
    "    ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
    "\n",
    "    # The x-axis is the number of heads, not the parameter count.\n",
    "    ax.set_xlabel('Number of Heads', fontsize=18)\n",
    "    ax.set_ylabel('Explained Variance', fontsize=18)\n",
    "    ax.grid(True, alpha=0.3, which='both')\n",
    "\n",
    "    ax.set_title(f\"Token Budget: {n_tokens // 1e6:.0f}M\", fontsize=18)\n",
    "    ax.tick_params(axis='y', which='major', labelsize=14)\n",
    "    ax.tick_params(axis='x', which='major', labelsize=12)\n",
    "    ax.tick_params(axis='both', which='minor', labelsize=10)\n",
    "\n",
    "# Replace seaborn's legend with hand-written labels.\n",
    "# NOTE(review): assumes the entries are [F title, 3 F values, M title, M values] and that\n",
    "# the F and M values are exactly the ones named below; confirm when the sweep changes.\n",
    "handles, labels = axs[0].get_legend_handles_labels()\n",
    "axs[0].legend(\n",
    "    handles=handles[1:4] + handles[5:],\n",
    "    labels=['$F = 2^{15}$', '$F = 2^{16}$', '$F = 2^{17}$', 'm = 2', 'm = 4', 'm = 8'],\n",
    "    fontsize=14,\n",
    "    loc='upper left'\n",
    ")\n",
    "\n",
    "# Only the middle panel keeps its x-axis label.\n",
    "axs[0].set_xlabel('')\n",
    "axs[2].set_xlabel('')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('heads vs explained variance qwen1.5B.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a01a42e-058a-453b-8b45-de6cfd39bfed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from matplotlib.lines import Line2D\n",
    "import numpy as np\n",
    "\n",
    "sns.set_style(\"whitegrid\")\n",
    "fig, axs = plt.subplots(1, 3, figsize=(15, 5), dpi=250, sharey=True)\n",
    "\n",
    "\n",
    "def pareto_frontier(xs, ys):\n",
    "    \"\"\"\n",
    "    Extract the 2D Pareto frontier from points (xs, ys).\n",
    "    Keeps points with strictly increasing y as x increases.\n",
    "    Returns arrays (pf_x, pf_y) sorted by pf_x ascending; both are empty for empty input.\n",
    "    \"\"\"\n",
    "    pts = sorted(zip(xs, ys), key=lambda p: p[0])\n",
    "    frontier = []\n",
    "    best_y = -np.inf\n",
    "    for x, y in pts:\n",
    "        if y > best_y:\n",
    "            frontier.append((x, y))\n",
    "            best_y = y\n",
    "    if not frontier:\n",
    "        return np.array([]), np.array([])\n",
    "    pf_x, pf_y = zip(*frontier)\n",
    "    return np.array(pf_x), np.array(pf_y)\n",
    "\n",
    "\n",
    "def smooth_pareto_front(xs, ys, deg=2, num=200, scaler=1.):\n",
    "    \"\"\"\n",
    "    Fit y as a degree-`deg` polynomial in log2(x) and sample the fit on a dense grid.\n",
    "\n",
    "    No Pareto filtering happens here: every given point is fitted (pass the points\n",
    "    through `pareto_frontier` first if only non-dominated ones should count).\n",
    "\n",
    "    Args:\n",
    "      xs (array-like):  x-coordinates, all > 0 (the fit is done in log2(x)).\n",
    "      ys (array-like):  y-coordinates.\n",
    "      deg (int):        degree of the fitting polynomial in log2(x).\n",
    "      num (int):        number of points to sample on the smooth curve.\n",
    "      scaler (float):   how far the curve extends past the data; the grid spans\n",
    "                        [min(xs) / (1 + scaler), max(xs) * (1 + scaler)].\n",
    "\n",
    "    Returns:\n",
    "      x_dense (np.ndarray):  ascending, geometrically spaced x-values (length=num).\n",
    "      y_smooth (np.ndarray): fitted y-values at x_dense (length=num).\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if fewer than deg + 1 points are given.\n",
    "    \"\"\"\n",
    "    pf_x = np.asarray(xs, dtype=float)\n",
    "    pf_y = np.asarray(ys, dtype=float)\n",
    "    if pf_x.size < deg + 1:\n",
    "        raise ValueError(\n",
    "            f\"need at least {deg + 1} points for a degree-{deg} fit, got {pf_x.size}\"\n",
    "        )\n",
    "\n",
    "    coeffs = np.polyfit(np.log2(pf_x), pf_y, deg)\n",
    "\n",
    "    # A geometric grid stays strictly positive (a linear grid extended by scaler >= 1\n",
    "    # would reach x <= 0, where log2 is undefined) and is evenly spaced on the\n",
    "    # log-scaled x-axis the curve is drawn on.\n",
    "    x_dense = np.geomspace(pf_x.min() / (1 + scaler), pf_x.max() * (1 + scaler), num)\n",
    "    y_smooth = np.polyval(coeffs, np.log2(x_dense))\n",
    "\n",
    "    return x_dense, y_smooth\n",
    "\n",
    "\n",
    "token_budgets = sorted(kronsae['Reference Tokens'].unique())\n",
    "assert len(token_budgets) <= len(axs), (\n",
    "    f\"{len(token_budgets)} token budgets but only {len(axs)} panels\"\n",
    ")\n",
    "\n",
    "for i, n_tokens in enumerate(token_budgets):\n",
    "    ax = axs[i]\n",
    "    plot_df = kronsae[kronsae['Reference Tokens'] == n_tokens]\n",
    "    plot_df = plot_df[plot_df[\"M\"] <= plot_df[\"N\"]]\n",
    "    baseline = topk[topk['Reference Tokens'] == n_tokens]\n",
    "\n",
    "    # KronSAE runs: hue = dictionary size F, line style = M.\n",
    "    sns.lineplot(\n",
    "        data=plot_df,\n",
    "        x='Parameters',\n",
    "        y='Explained Variance',\n",
    "        hue='F',\n",
    "        style='M',\n",
    "        markers=True,\n",
    "        dashes=True,\n",
    "        palette='tab10',\n",
    "        markersize=8,\n",
    "        linewidth=2,\n",
    "        ax=ax,\n",
    "        legend=i == 0\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel('Trainable Parameters', fontsize=18)\n",
    "    ax.set_ylabel('Explained Variance', fontsize=18)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "    # TopK baselines as large stars, coloured by F with the first three tab10 colours.\n",
    "    sns.scatterplot(\n",
    "        baseline,\n",
    "        x='Parameters',\n",
    "        y='Explained Variance',\n",
    "        hue='F',\n",
    "        s=400,\n",
    "        marker='*',\n",
    "        palette=sns.color_palette('tab10', 3),\n",
    "        legend=False,\n",
    "        ax=ax\n",
    "    )\n",
    "\n",
    "    # Black curve: fit through all TopK runs of this token budget.\n",
    "    x_front, y_front = smooth_pareto_front(\n",
    "        baseline['Parameters'].values,\n",
    "        baseline['Explained Variance'].values,\n",
    "        num=200,\n",
    "        scaler=1.9,\n",
    "    )\n",
    "    ax.plot(\n",
    "        x_front, y_front,\n",
    "        color='black',\n",
    "        lw=1.1,\n",
    "        alpha=0.7,\n",
    "        label='TopK fitted front' if i == 0 else None\n",
    "    )\n",
    "\n",
    "    # Red curve: fit through the highest-explained-variance KronSAE run for each F.\n",
    "    best_per_f = plot_df.loc[plot_df.groupby(\"F\")['Explained Variance'].idxmax()]\n",
    "    x_front, y_front = smooth_pareto_front(\n",
    "        best_per_f['Parameters'].values,\n",
    "        best_per_f['Explained Variance'].values,\n",
    "        num=200,\n",
    "        scaler=1.9,\n",
    "    )\n",
    "    ax.plot(\n",
    "        x_front, y_front,\n",
    "        color='red',\n",
    "        lw=1.1,\n",
    "        alpha=0.7,\n",
    "        label='KronSAE fitted front' if i == 0 else None\n",
    "    )\n",
    "\n",
    "    # Pad the view by 15% around every plotted run.\n",
    "    all_xs = np.concatenate([baseline['Parameters'].values, plot_df['Parameters'].values])\n",
    "    ax.set_xlim(all_xs.min() - 0.15 * all_xs.min(), all_xs.max() + 0.15 * all_xs.max())\n",
    "    ax.set_title(f\"Token Budget: {n_tokens // 1e6:.0f}M\", fontsize=18)\n",
    "\n",
    "    # No ScalarFormatter here: set_xscale installs the log formatter and would replace\n",
    "    # any formatter set before it.\n",
    "    ax.set_xscale(\"log\", base=2)\n",
    "    ax.set_ylim(0.775, 0.87)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "    ax.tick_params(axis='both', which='minor', labelsize=10)\n",
    "\n",
    "# Rebuild the legend on the first panel with hand-written labels.\n",
    "axs[0].get_legend().remove()\n",
    "handles, labels = axs[0].get_legend_handles_labels()\n",
    "# NOTE(review): assumes seaborn's entries come first, ordered\n",
    "# [F title, 3 F values, M title, 3 M values]; the fitted-front lines follow them.\n",
    "new_handles = handles[5:8]  # the three M (line-style) entries\n",
    "baseline_handles = [\n",
    "    Line2D([0], [0],\n",
    "           marker=\"*\",\n",
    "           color=\"black\",\n",
    "           linestyle=\"\",\n",
    "           markersize=12, label=\"TopK SAE\"),\n",
    "]\n",
    "\n",
    "axs[0].legend(\n",
    "    handles=handles[1:4] + new_handles + baseline_handles,\n",
    "    labels=['$F = 2^{15}$', '$F = 2^{16}$', '$F = 2^{17}$', 'm = 2', 'm = 4', 'm = 8', 'TopK'], fontsize=14,\n",
    "    loc='upper left',\n",
    ")\n",
    "axs[0].set_xlabel('')\n",
    "axs[2].set_xlabel('')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('parameters vs explained variance qwen1.5b.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1771d518",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Newer W&B export; `data` gets one row per run.\n",
    "new_df = pd.read_csv(\"wandb.csv\")\n",
    "\n",
    "sae_type_to_name = {\n",
    "    \"topk\": \"TopK\",\n",
    "    \"mul_fractal_topk\": \"KronSAE\",\n",
    "    \"jumprelu\": \"JumpReLU\",\n",
    "    \"mul_fractal_jumprelu\": \"KronSAE JumpReLU\"\n",
    "}\n",
    "\n",
    "\n",
    "def _run_record(run):\n",
    "    \"\"\"Turn one row of `new_df` into one row of `data`.\n",
    "\n",
    "    The reference token budget is the second-to-last '_'-separated part of the\n",
    "    run name; the model size is the last '-'-separated part of the model name.\n",
    "    Raises KeyError for an `sae_type` missing from `sae_type_to_name`.\n",
    "    \"\"\"\n",
    "    budget = run['Name'].split(\"_\")[-2]\n",
    "    return {\n",
    "        'Name': run['Name'],\n",
    "        \"Setup\": f\"{run['dict_size']} / {budget}\",\n",
    "        'H': run['num_heads'],\n",
    "        'M': run['num_mkeys'],\n",
    "        'N': run['num_nkeys'],\n",
    "        'F': run['dict_size'],\n",
    "        'Runtime': run['Runtime'],\n",
    "        'SAE': sae_type_to_name[run['sae_type']],\n",
    "        'Explained Variance': run['explained_variance'],\n",
    "        'Parameters': run['parameters_count'],\n",
    "        'FLOPS': run['accum_num_flops'],\n",
    "        'Tokens': run['n_tokens_true'],\n",
    "        \"Reference tokens\": budget,\n",
    "        'Model size': run['model_name'].split('-')[-1],\n",
    "    }\n",
    "\n",
    "\n",
    "data = pd.DataFrame([_run_record(run) for _idx, run in new_df.iterrows()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981df067",
   "metadata": {},
   "outputs": [],
   "source": [
    "import colorsys\n",
    "import numpy as np\n",
    "\n",
    "# Best explained variance per reference token budget for the 1.5B model:\n",
    "# KronSAE bars shaded by M next to a TopK bar, one panel per dictionary size.\n",
    "fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=250, sharex=True, sharey=True)\n",
    "tab10 = plt.cm.tab10.colors  # per-panel base colours (only RGB is used)\n",
    "topk_color = (0.267, 0.267, 0.267)  # #444444 (dark gray)\n",
    "\n",
    "for i, dict_size in enumerate([2**15, 2**16, 2**17]):\n",
    "    subset = data[(data['F'] == dict_size) & (data[\"Model size\"] == \"1.5B\")]\n",
    "\n",
    "    # KronSAE runs with M < 16; M becomes a string hue level ordered numerically.\n",
    "    kron = subset[(subset[\"SAE\"] == \"KronSAE\") & (subset[\"M\"] < 16)].copy()\n",
    "    if not kron.empty:\n",
    "        kron['M_float'] = kron['M'].astype(float)\n",
    "        m_vals = sorted(kron['M_float'].unique())\n",
    "        kron['M'] = kron['M_float'].apply(lambda x: str(int(x)))\n",
    "\n",
    "        # Shades of this panel's base colour: lighter for small M, darker for large M.\n",
    "        base_hue, _base_lightness, base_saturation = colorsys.rgb_to_hls(*tab10[i][:3])\n",
    "        kron_colors = [\n",
    "            colorsys.hls_to_rgb(base_hue, lightness, base_saturation)\n",
    "            for lightness in np.linspace(0.8, 0.4, len(m_vals))\n",
    "        ]\n",
    "    else:\n",
    "        m_vals = []\n",
    "        kron_colors = []\n",
    "\n",
    "    # TopK runs form a single hue level. A distinct name keeps the module-level\n",
    "    # `topk` table (built from wandb_v3.csv) from being overwritten by this cell.\n",
    "    topk_subset = subset[subset[\"SAE\"] == \"TopK\"].copy()\n",
    "    topk_subset['M'] = 'TopK'\n",
    "\n",
    "    combined = pd.concat([kron, topk_subset])\n",
    "    m_order = [str(int(m)) for m in m_vals] + ['TopK']\n",
    "    palette = kron_colors + [topk_color]\n",
    "\n",
    "    sns.barplot(\n",
    "        combined,\n",
    "        x=\"Reference tokens\",\n",
    "        y=\"Explained Variance\",\n",
    "        hue=\"M\",\n",
    "        estimator=\"max\",\n",
    "        ax=axs[i],\n",
    "        legend=i == 0,\n",
    "        hue_order=m_order,\n",
    "        palette=palette,\n",
    "        order=[\"500m\", \"1000m\"],\n",
    "        dodge=True\n",
    "    )\n",
    "\n",
    "    axs[i].set_ylim(0.76, 0.9)\n",
    "    axs[i].set_title(f\"Dictionary: {dict_size}\")\n",
    "    axs[i].grid(alpha=0.35, axis='y', linestyle='--')\n",
    "\n",
    "    if i != 1:\n",
    "        axs[i].set_xlabel(\"\")\n",
    "\n",
    "# Legend on the first panel only: smallest M, largest M and TopK.\n",
    "# NOTE(review): assumes the first panel has at least two KronSAE M levels.\n",
    "handles, labels = axs[0].get_legend_handles_labels()\n",
    "axs[0].legend(\n",
    "    handles=[handles[0], handles[-2], handles[-1]],\n",
    "    labels=[f\"m={labels[0]}\", f\"m={labels[-2]}\", labels[-1]],\n",
    "    loc=\"upper left\",\n",
    "    frameon=True,\n",
    "    framealpha=0.9\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"maximum_performance.pdf\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
