{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a2e888",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faacf1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('wandb_pythia_1_4b.csv') # wandb_v3.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8fa461e",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols_to_plot = [\n",
    "    'sweep_pair.num_heads',\n",
    "    'sweep_pair.num_mkeys',\n",
    "    'sweep_pair.num_nkeys',\n",
    "    'sweep_pair.dict_size',\n",
    "    'explained_variance',\n",
    "    'parameters_count',\n",
    "    'accum_num_flops',\n",
    "    'n_tokens',\n",
    "    'performance/recovery_from_zero',\n",
    "    'performance/recovery_from_mean'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efffa462",
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping = {\n",
    "    'sweep id': (\"1.4B\", 100 * 1e6),\n",
    "    'sweep id': (\"1.4B\", 500 * 1e6),\n",
    "    'sweep id': (\"1.4B\", 1000 * 1e6),\n",
    "    'sweep id': (\"1.4B\", 1000 * 1e6),\n",
    "    \n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3d6126",
   "metadata": {},
   "outputs": [],
   "source": [
    "kronsae = []\n",
    "topk = []\n",
    "for i, row in df.iterrows():\n",
    "    if 'exp71' in row['Name'] and row['Sweep'] in mapping.keys():\n",
    "\n",
    "        model_size, ref_tokens = mapping[row['Sweep']]\n",
    "\n",
    "        entry = {\n",
    "            'Name': row['Name'],\n",
    "            'H': row['sweep_pair.num_heads'],\n",
    "            'M': row['sweep_pair.num_mkeys'],\n",
    "            'N': row['sweep_pair.num_nkeys'],\n",
    "            'F': row['sweep_pair.dict_size'],\n",
    "            'Explained Variance': row['explained_variance'],\n",
    "            'CE_from_zero': row['performance/recovery_from_zero'],\n",
    "            'CE_from_mean': row['performance/recovery_from_mean'],\n",
    "            'Parameters': row['parameters_count'],\n",
    "            'FLOPS': row['accum_num_flops'],\n",
    "            'Tokens': row['n_tokens'],\n",
    "            'Reference Tokens': ref_tokens,\n",
    "            'Model size': model_size\n",
    "        }\n",
    "        \n",
    "        if 'kronsae' in row['Name']:\n",
    "            entry['Experts'] = entry['H'] * entry['M']\n",
    "            kronsae.append(entry)\n",
    "        else:\n",
    "            topk.append(entry)\n",
    "\n",
    "kronsae = pd.DataFrame(kronsae)\n",
    "topk = pd.DataFrame(topk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dabd0a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_size = '1.5B'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7314f77-b3c3-45ef-9551-b871c167d151",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from matplotlib.lines import Line2D\n",
    "import numpy as np\n",
    "from scipy.interpolate import PchipInterpolator\n",
    "# Set style\n",
    "sns.set_style(\"whitegrid\")\n",
    "fig, axs = plt.subplots(1, 3, figsize=(15, 5), dpi=250, sharey=True) #, sharex=True, sharey=True)\n",
    "#main_palette = \"cividis\"\n",
    "main_palette = \"colorblind\"\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def pareto_frontier(xs, ys):\n",
    "    \"\"\"\n",
    "    Extract the 2D Pareto frontier from points (xs, ys).\n",
    "    Keeps points with strictly increasing y as x increases.\n",
    "    Returns arrays (pf_x, pf_y) sorted by pf_x ascending.\n",
    "    \"\"\"\n",
    "    pts = sorted(zip(xs, ys), key=lambda p: p[0])\n",
    "    frontier = []\n",
    "    best_y = -np.inf\n",
    "    for x, y in pts:\n",
    "        if y > best_y:\n",
    "            frontier.append((x, y))\n",
    "            best_y = y\n",
    "    pf_x, pf_y = zip(*frontier)\n",
    "    return np.array(pf_x), np.array(pf_y)\n",
    "\n",
    "from scipy.optimize import curve_fit\n",
    "from scipy.stats import linregress\n",
    "\n",
    "def smooth_pareto_front(xs, ys, deg=2, num=200, scaler = 1.):\n",
    "    \"\"\"\n",
    "    Compute and return a smooth polynomial approximation of the Pareto front.\n",
    "\n",
    "    Args:\n",
    "      xs (array-like):  x-coordinates of all points.\n",
    "      ys (array-like):  y-coordinates of all points.\n",
    "      deg (int):        degree of the fitting polynomial.\n",
    "      num (int):        number of points to sample on the smooth curve.\n",
    "\n",
    "    Returns:\n",
    "      x_dense (np.ndarray):  sorted x-values along the front (length=num).\n",
    "      y_smooth (np.ndarray): fitted y-values p(x_dense) (length=num).\n",
    "    \"\"\"\n",
    "    # 1) Get the raw Pareto-optimal points\n",
    "    print(xs, ys)\n",
    "    #pf_x, pf_y = pareto_frontier(xs, ys)\n",
    "    pf_x, pf_y = np.array(xs), np.array(ys)\n",
    "    print(pf_x, pf_y)\n",
    "\n",
    "    # 2) Fit a degree-deg polynomial via least squares\n",
    "    #    p(x) = c[0]*x^deg + ... + c[deg]\n",
    "    #coeffs = np.polyfit(pf_x, pf_y, deg=deg)     \n",
    "    # def func(x, a, b, c):\n",
    "    #     return a * x**2 + b * x + c  # a and d are redundant\n",
    "    def func(x, a, b, c):\n",
    "        return a * x**2 + b * x + c  # a and d are redundant\n",
    "        \n",
    "    popt, pcov = curve_fit(func, np.log(pf_x), pf_y)\n",
    "\n",
    "    # 3) Evaluate on a dense grid\n",
    "    x_dense = np.linspace(pf_x.min() - scaler * pf_x.min(), pf_x.max() + scaler* pf_x.max(), num)\n",
    "    #y_smooth = np.polyval(coeffs, x_dense)  \n",
    "    y_smooth = func(np.log(x_dense), *popt) #reg.predict(np.array(x_dense))\n",
    "\n",
    "    return x_dense, y_smooth\n",
    "\n",
    "\n",
    "for i, n_tokens in enumerate(sorted(kronsae['Reference Tokens'].unique(), reverse=False)):\n",
    "    plot_df = kronsae[kronsae['Reference Tokens'] == n_tokens]\n",
    "    plot_df = plot_df[plot_df[\"M\"] <= plot_df[\"N\"]]\n",
    "    \n",
    "    ax = axs[i]\n",
    "    # Create plot with log scale for heads (powers of 2)\n",
    "    sns.lineplot(\n",
    "        data=plot_df,\n",
    "        x='Parameters',  # Heads (powers of 2)\n",
    "        y='Explained Variance',\n",
    "        hue='F',\n",
    "        style='M',\n",
    "        markers=True,\n",
    "        dashes=True,\n",
    "        palette=main_palette,# 'tab10',\n",
    "        markersize=8,\n",
    "        linewidth=2,\n",
    "        ax=ax,\n",
    "\n",
    "        legend=i==0\n",
    "    )\n",
    "\n",
    "    # Set logarithmic x-axis with proper tick formatting\n",
    "    #ax.set_xscale('log', base=2)\n",
    "    # ax.set_xticks(plot_df['Parameters'].unique())\n",
    "    ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
    "\n",
    "    # Customize plot\n",
    "    # ax.set(\n",
    "    #     xlabel='Trainable Parameters',\n",
    "    #     ylabel='Explained Variance'\n",
    "    # )\n",
    "    ax.set_xlabel('Trainable Parameters', fontsize=18)\n",
    "    ax.set_ylabel('Explained Variance', fontsize=18)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    \n",
    "\n",
    "    #ax.set_title(f\"Reference tokens: {n_tokens // 1e6:.0f}M\")\n",
    "\n",
    "    # plt.ylim(0.8, 0.875)\n",
    "\n",
    "    #cplt = sns.color_palette('tab10', 3)\n",
    "    cplt = sns.color_palette(main_palette, 3)\n",
    "    g_results = sns.scatterplot(topk[topk['Reference Tokens'] == n_tokens], \n",
    "                    x = 'Parameters', \n",
    "                    y = 'Explained Variance', \n",
    "                    hue = 'F', \n",
    "                    s = 400, \n",
    "                    marker = '*', \n",
    "                    palette=cplt, \n",
    "                    legend = False, \n",
    "                    ax=ax)\n",
    "\n",
    "\n",
    "    baseline = topk[topk['Reference Tokens'] == n_tokens]\n",
    "    xs = baseline['Parameters'].values\n",
    "    ys = baseline['Explained Variance'].values\n",
    "    #print(xs, ys)\n",
    "    \n",
    "\n",
    "    # 3b) Fit a smooth Pareto front of degree 2\n",
    "    x_front, y_front = smooth_pareto_front(xs, ys, num=200, scaler=1.9)\n",
    "\n",
    "    # 3c) Draw the fitted curve\n",
    "    ax.plot(\n",
    "        x_front, y_front,\n",
    "        color='black',\n",
    "        lw=1.1,\n",
    "        alpha=0.7,\n",
    "        label='Fitted Pareto front' if i==0 else None\n",
    "    )\n",
    "\n",
    "    # baseline = plot_df[plot_df['Reference Tokens'] == n_tokens]\n",
    "    # xs = baseline['Parameters'].values\n",
    "    # ys = baseline['Explained Variance'].values\n",
    "    idx = [grp['Explained Variance'].idxmax() for D_val, grp in plot_df.groupby(\"F\")]\n",
    "    xs1 = [grp['Parameters'][idx[j]] for j, (D_val, grp) in enumerate(plot_df.groupby(\"F\"))]\n",
    "    ys1 = [grp['Explained Variance'][idx[j]] for j, (D_val, grp) in enumerate(plot_df.groupby(\"F\"))]\n",
    " \n",
    "\n",
    "    # 3b) Fit a smooth Pareto front of degree 2\n",
    "    x_front, y_front = smooth_pareto_front(xs1, ys1,  num=200, scaler=1.9)\n",
    "\n",
    "    all_xs = np.concatenate([baseline['Parameters'].values, plot_df['Parameters'].values])\n",
    "\n",
    "    ax.set_xlim(all_xs.min() - 0.15 * all_xs.min(), all_xs.max() + 0.15 * all_xs.max())\n",
    "\n",
    "    # 3c) Draw the fitted curve\n",
    "    ax.plot(\n",
    "        x_front, y_front,\n",
    "        color='red',\n",
    "        lw=1.1,\n",
    "        alpha=0.7,\n",
    "\n",
    "        label='Fitted Pareto front' if i==0 else None\n",
    "    )\n",
    "    ax.set_title(f\"Token Budget: {n_tokens // 1e6:.0f}M\", fontsize=18)\n",
    "    \n",
    "    ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())\n",
    "    ax.set_xscale(\"log\", base=2)\n",
    "    ax.set_ylim(0.65, 0.84)\n",
    "\n",
    "    ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "    ax.tick_params(axis='both', which='minor', labelsize=10)\n",
    "    \n",
    "\n",
    "axs[0].get_legend().remove()\n",
    "handles, labels = axs[0].get_legend_handles_labels()\n",
    "baseline_handles = [\n",
    "    Line2D([0], [0], \n",
    "           marker=\"*\", \n",
    "           #color=\"green\", \n",
    "           color = \"black\",\n",
    "           linestyle=\"\",\n",
    "           markersize=12, label=f\"TopK SAE\"),\n",
    "]\n",
    "axs[0].legend(\n",
    "    #handles=handles[1:4] + handles[5:],  # Skip the first element (title)\n",
    "    handles=handles[1:4] + handles[5:7] + baseline_handles, #+ baseline_handles,  # Skip the first element (title)\n",
    "    #labels=['D = 32768', 'D = 65536', 'D = 131072', 'M = 2', 'M = 4', 'M=8'],# 'TopK'],\n",
    "    labels=['$F = 2^{15}$', '$F = 2^{16}$', '$F = 2^{17}$', 'm = 2', 'm = 4',  'TopK'],fontsize=14,\n",
    "    loc='upper left',\n",
    ")\n",
    "axs[0].set_xlabel('')\n",
    "axs[2].set_xlabel('')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'parameters vs explained variance pythia1.5b.pdf')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
