{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8c1d8b5-ab74-47a8-a944-8e96344b26e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "from transformer_lens import HookedTransformer\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "from tqdm import tqdm\n",
    "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner, SAE\n",
    "from dataclasses import dataclass\n",
    "from stitching.stitching_utils import open_experiment\n",
    "from stitching.losses import next_token_cross_entropy_loss\n",
    "import seaborn as sns\n",
    "from matplotlib.lines import Line2D\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib\n",
    "import pandas as pd\n",
    "device = \"cuda\"\n",
    "\n",
    "print(\"Using device:\", device)\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "import yaml\n",
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0cbf6d8-e4a6-42ff-ac4f-ded0aa6c8eb9",
   "metadata": {},
   "source": [
    "# load from saved"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7652740a-c556-4858-bbfb-938f2b08720b",
   "metadata": {},
   "outputs": [],
   "source": [
    "histories = {\n",
    "    k: pd.read_csv(f'data/sae_metrics_{k}.csv') for k in run_id_map.keys()\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cdfd167-c894-484b-b006-148d079f0a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "rolling_iterations = 1000\n",
    "for k, v_df in histories.items():\n",
    "    v_df['moving_avg'] = (1-v_df['metrics/explained_variance']).rolling(window=1000).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "094898db-2864-4a60-b671-d4f39724114e",
   "metadata": {},
   "outputs": [],
   "source": [
    "estimated_flop_counts_cached = {\n",
    "    '4k': 1.546188e+16,\n",
    "    '8k': 3.092376e+16,\n",
    "    '16k': 6.184753e+16,\n",
    "    '32k': 1.236951e+17,\n",
    "    '65k': 2.473901e+17\n",
    "}\n",
    "# 8.2 * 10^16 for pythia70m\n",
    "stitch_estimated_flops = 7.247757e+14  # for one epoch\n",
    "indices_map = {\n",
    "    '4k': 0,\n",
    "    '8k': 1,\n",
    "    '16k': 2,\n",
    "    '32k': 3,\n",
    "    '65k': 4\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1efcbfcb-da22-4f5c-9a8c-f2e57be6a1c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_first_step(df, loss_col, loss_levels):\n",
    "    \"\"\"Find the first step where the loss hits each level.\"\"\"\n",
    "    steps = []\n",
    "    for level in loss_levels:\n",
    "        step = df[df[loss_col] <= level].index.min()\n",
    "        #print(step)\n",
    "        steps.append(step if pd.notnull(step) else np.nan)\n",
    "    return steps\n",
    "colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(estimated_flop_counts_cached))[::-1])  # A range of blue tones\n",
    "\n",
    "# Define the loss levels you care about\n",
    "matplotlib.rcParams.update({'font.size': 14})\n",
    "loss_levels = np.arange(0.08, 0.3, 0.01)\n",
    "for size in estimated_flop_counts_cached.keys():\n",
    "    stitched_df = histories[f'stitched_{size}'].set_index('_step')\n",
    "    scratch_df = histories[f'scratch_{size}'].set_index('_step')\n",
    "\n",
    "    # Find first step for each curve\n",
    "    steps_curve_1 = np.array(find_first_step(stitched_df, \"moving_avg\", loss_levels)) / 120_000 * estimated_flop_counts_cached[size] + stitch_estimated_flops*2\n",
    "    steps_curve_2 = np.array(find_first_step(scratch_df, \"moving_avg\", loss_levels)) / 120_000 *  estimated_flop_counts_cached[size]\n",
    "\n",
    "    # Compute the difference in steps\n",
    "    step_differences = (np.array(steps_curve_2) - np.array(steps_curve_1)) / np.array(steps_curve_2)\n",
    "\n",
    "    # Plot the result\n",
    "\n",
    "    plt.plot(1-loss_levels[::-1], step_differences[::-1], label=f'{size}, k=64', color=colors[indices_map[size]])\n",
    "plt.xlabel(\"Explained Variance Threshold\")\n",
    "plt.ylabel(\"Relative FLOPs Difference\")\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.ylim(-1)\n",
    "plt.axhline(0, color='black', linestyle='dashed', alpha=0.5)\n",
    "\n",
    "#plt.savefig(\"results/figures/70m_160m_relative_flops_diff.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b60c262-0ad9-480b-8493-17066d7240ae",
   "metadata": {},
   "source": [
    "## Scaling Law"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f956581f-f75d-4703-9273-b1769d904fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95ca787e-0b88-4432-8b30-cdcd7de071bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Experiment:\n",
    "    flop_count: int\n",
    "    history: pd.DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "554c3262-792c-48a8-941a-4de37d99c051",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_map = {}\n",
    "for k, (wandb_id, sae_id) in run_id_map.items():\n",
    "    history = histories[k]\n",
    "    sae_size = k.split('_')[-1]\n",
    "    #_, indices = cut_series_on_convergence(history['losses/mse_loss'], 100, 1e-4, 1000)\n",
    "    #print(k, indices)\n",
    "    #history_df = history.iloc[indices].copy()\n",
    "    history_df = history \n",
    "    #if 'stitched' in k and k != 'stitched_65k':\n",
    "    #    sae_id = run_id_map[f'scratch_{k.split('_')[-1]}'][1]\n",
    "    #with open(f'pythia-160m-sae-topk-checkpoints/{sae_id}/flop_count.txt', 'r') as flop_file:\n",
    "    #    total_flop_count = int(flop_file.readline())\n",
    "    total_flop_count = estimated_flop_counts_cached[sae_size]\n",
    "    history_df['flops'] = history_df['_step'] / 120000 * total_flop_count \n",
    "    #if 'stitched' in k:\n",
    "    #     history_df['flops'] += stitch_estimated_flops # (approximate cost of stitch)\n",
    "    exp_map[k] = Experiment(total_flop_count, history_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95617abf-f0b6-4d5c-9867-6062dbcc5a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_best_at_threshes(histories, key, thresholds):\n",
    "    # expects flops key in histories\n",
    "    best = []\n",
    "    for threshold in thresholds:\n",
    "        lol = []\n",
    "        for history in histories:\n",
    "            if history['flops'].iloc[0] <= threshold and threshold <= history['flops'].iloc[-1]:\n",
    "                lol.append(history[history['flops'] >= threshold][key].iloc[0])\n",
    "        best_mse = min(lol)\n",
    "        best.append(best_mse)\n",
    "    return best"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd2fac4d-6761-4a73-a797-4bddfab2d877",
   "metadata": {},
   "outputs": [],
   "source": [
    "scratches = [exp.history for k, exp in exp_map.items() if 'scratch' in k]\n",
    "stitches = [exp.history for k, exp in exp_map.items() if 'stitch' in k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bc8d33a-420b-4b49-89ef-1bb5e657d876",
   "metadata": {},
   "outputs": [],
   "source": [
    "threshes = np.logspace(15, 17, 20)\n",
    "#threshes = np.linspace(10**15, 10**17, 10)#10 ** (np.linspace(15, 17, 20))\n",
    "scratch_line = np.array(find_best_at_threshes(scratches, 'moving_avg', threshes))\n",
    "stitch_line = np.array(find_best_at_threshes(stitches, 'moving_avg', threshes))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10b424ba-9e6a-464e-8501-1bad76137c09",
   "metadata": {},
   "source": [
    "Try fitting a proper scaling law. We don't have enough high-compute data points to get the irreducible loss to be correct."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda7e610-c7e7-4923-93f1-063d5963c521",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import curve_fit\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def scaling_law(C, e, alpha, beta):\n",
    "    \"\"\"Model: L = e + alpha * C**beta\"\"\"\n",
    "    return e + alpha * C**(-beta)\n",
    "\n",
    "def fit_scaling_law(C, L):\n",
    "    p0 = [min(L)*0.9, (max(L)-min(L)), 0.1]\n",
    "    bounds = ([0,0,0], [np.inf, np.inf, np.inf])\n",
    "    \n",
    "    params, covariance = curve_fit(scaling_law, C, L, p0=p0, bounds=bounds,maxfev=100000)\n",
    "    e_fit, alpha_fit, beta_fit = params\n",
    "    param_errors = np.sqrt(np.diag(covariance))\n",
    "    \n",
    "    print(\"Fitted parameters:\")\n",
    "    print(f\"  Irreducible loss (e):     {e_fit:.4e} ± {param_errors[0]:.4e}\")\n",
    "    print(f\"  Scale coefficient (α):    {alpha_fit:.4e} ± {param_errors[1]:.4e}\")\n",
    "    print(f\"  Exponent (β):             {beta_fit:.4f} ± {param_errors[2]:.4f}\")\n",
    "    \n",
    "    # Plot data and fit on log-log axes\n",
    "    return params\n",
    "scratch_params = fit_scaling_law(threshes, scratch_line)\n",
    "stitch_params = fit_scaling_law(threshes, stitch_line)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09ce99c4-03b5-416a-8111-a3554bb9a104",
   "metadata": {},
   "source": [
    "### Generate plot with fitted scaling law"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae2bbacd-b920-43d1-bd35-8e4cf9c5de82",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c01c71d-2883-4bf3-bd89-d0cf88430a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "palette1 = plt.cm.Blues(np.linspace(0.4, 0.9, len(indices_map)))  # A range of blue tones\n",
    "palette2 = plt.cm.Oranges(np.linspace(0.4, 0.9, len(indices_map)))  # A range of orange tones\n",
    "fig,ax = plt.subplots(1,1,figsize=(6,4))\n",
    "key = 'moving_avg' # 'losses/mse_loss'\n",
    "for k, exp in exp_map.items():\n",
    "    size_index = indices_map[k.split('_')[-1]]\n",
    "    palette = palette1 if 'scratch' in k else palette2\n",
    "    plt.plot(np.log10(exp.history['flops']), np.log10(exp.history[key]), color=palette[size_index],alpha=0.75)\n",
    "    #scatter['scratch_x'].append(np.log10(exp.history['flops'].iloc[-1]))\n",
    "    #scatter['scratch_y'].append(np.log10(exp.history[key].iloc[-1]))\n",
    "yticks = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])\n",
    "plt.yticks(np.log10(yticks), yticks)\n",
    "plt.xticks([14, 15, 16, 17], [r\"$10^{14}$\", r\"$10^{15}$\", r\"$10^{16}$\", r\"$10^{17}$\"])\n",
    "plt.xlabel(\"FLOPs\")\n",
    "plt.ylabel(\"Fraction of Unexplained Variance\")\n",
    "#sns.lineplot(x=np.log10(threshes), y=np.log10(scaling_law(threshes, *scratch_params)), label='scratch', color='tab:blue')\n",
    "#sns.lineplot(x=np.log10(threshes), y=np.log10(scaling_law(threshes, *stitch_params)), label='stitch', color='tab:orange')\n",
    "p = sns.regplot(x=np.log10(threshes), y=np.log10(stitch_line), ci=None, truncate=False, label='stitch',color='tab:orange', scatter=False, line_kws={'linestyle':'dashed', 'alpha': 0.5})\n",
    "p = sns.regplot(x=np.log10(threshes), y=np.log10(scratch_line), ci=None, truncate=False, label='scratch', color='tab:blue', scatter=False,line_kws={'linestyle':'dashed', 'alpha': 0.5})\n",
    "\n",
    "def fit_regress_on_output(p, idx):\n",
    "    print(len(p.get_lines()))\n",
    "    xdat = p.get_lines()[idx].get_xdata()\n",
    "    ydat = p.get_lines()[idx].get_ydata()\n",
    "    nans = np.isnan(xdat) | np.isnan(ydat)\n",
    "    return scipy.stats.linregress(x=xdat[~nans],y=ydat[~nans])\n",
    "slope_stitch, intercept_stitch, _,p_value,std_err= fit_regress_on_output(p, -2)\n",
    "slope_scratch, intercept_scratch, _,p_value,std_err = fit_regress_on_output(p, -1)\n",
    "print(\"stitch\", slope_stitch, intercept_stitch, f\"L(C) = {10**intercept_stitch}C^{slope_stitch}\")\n",
    "print(\"scratch\", slope_scratch, intercept_scratch, f\"L(C) = {10**intercept_scratch}C^{slope_scratch}\")\n",
    "\n",
    "legend_handles = [\n",
    "    Line2D([0], [0], color=palette1[1], lw=2, label=f'random'),# L = {scratch_params[0]:.2e} + {scratch_params[1]:.2e} C^-{scratch_params[2]:.2f}'),\n",
    "    Line2D([0], [0], color=palette2[1], lw=2, label=f'stitch'), #L = {stitch_params[0]:.2e} + {stitch_params[1]:.2e} C^-{stitch_params[2]:.2f}'),\n",
    "]\n",
    "\n",
    "# Create a ScalarMappable for the colorbar\n",
    "sizes = np.array([4096, 8192, 16384, 32768, 65536])\n",
    "# Create two separate colormaps\n",
    "blue_cmap = mcolors.LinearSegmentedColormap.from_list(\"Blues_subset\", palette1)\n",
    "orange_cmap = mcolors.LinearSegmentedColormap.from_list(\"Oranges_subset\", palette2)\n",
    "\n",
    "# Create ScalarMappables for each colormap\n",
    "norm = mcolors.LogNorm(vmin=sizes.min() / 2, vmax=sizes.max() * 2)\n",
    "blue_sm = plt.cm.ScalarMappable(cmap=blue_cmap, norm=norm)\n",
    "orange_sm = plt.cm.ScalarMappable(cmap=orange_cmap, norm=norm)\n",
    "\n",
    "blue_sm.set_array([])\n",
    "orange_sm.set_array([])\n",
    "\n",
    "# Add the colorbars side by side on the right\n",
    "cbar_orange = plt.colorbar(orange_sm, ax=ax, location=\"right\", pad=-0.08)\n",
    "\n",
    "cbar_blue = plt.colorbar(blue_sm, ax=ax, location=\"right\", pad=0.04)\n",
    "\n",
    "# Add a single shared label\n",
    "cbar_orange.ax.set_ylabel(\"# Of Latents\",labelpad=5)\n",
    "\n",
    "# Set matching ticks\n",
    "for cbar in [cbar_blue, cbar_orange]:\n",
    "    cbar.set_ticks(sizes)\n",
    "\n",
    "cbar_orange.set_ticklabels([f\"$2^{{{int(np.log2(size))}}}$\" for size in sizes])\n",
    "# Add the custom legend\n",
    "plt.legend(handles=legend_handles)\n",
    "plt.grid(axis='x')\n",
    "#plt.savefig(\"results/figures/sae_scaling_law_70m_160m.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a98a4f9f-864b-46b1-ae91-91a31b397dcc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4917f51a-f99f-4a2e-b3c6-093310dd83f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
