{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "import pandas as pd \n",
    "import wandb\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "from tqdm import tqdm\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "sns.set_context(\"paper\", font_scale=1.5, rc={\"lines.linewidth\": 2.5})\n",
    "\n",
    "# Credentials must not live in the notebook: use WANDB_API_KEY from the\n",
    "# environment and only prompt (without echo) when it is missing.\n",
    "# NOTE(review): the key that used to be hardcoded here is still in version\n",
    "# history and should be revoked.\n",
    "if not os.environ.get('WANDB_API_KEY'):\n",
    "    import getpass\n",
    "    os.environ['WANDB_API_KEY'] = getpass.getpass('W&B API key: ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_df_series(series):\n",
    "    \"\"\"Return the element-wise mean of the DataFrames contained in `series`.\"\"\"\n",
    "    count = len(series)\n",
    "    total = None\n",
    "    for frame in series:\n",
    "        if total is None:\n",
    "            # Start from a copy so the first frame itself is never modified.\n",
    "            total = frame.copy()\n",
    "            continue\n",
    "        total += frame\n",
    "    # Sum divided by the number of frames gives the mean.\n",
    "    return total / count\n",
    "\n",
    "\n",
    "def _select_best(frame, group_keys, metric):\n",
    "    \"\"\"Return, for each `group_keys` group, the row with the best `metric`.\n",
    "\n",
    "    Metrics whose name contains 'loss' are minimised; all others are maximised.\n",
    "    \"\"\"\n",
    "    grouped = frame.groupby(group_keys)[metric]\n",
    "    best_idx = grouped.idxmin() if 'loss' in metric else grouped.idxmax()\n",
    "    return frame.loc[best_idx]\n",
    "\n",
    "\n",
    "def get_runs(filters, proj='struct_mlp', tune_on='test', avg_seeds=True, history=True):\n",
    "    \"\"\"Fetch W&B runs and reduce them to one tuned row per configuration.\n",
    "\n",
    "    Args:\n",
    "        filters: filter dict forwarded to ``wandb.Api().runs``.\n",
    "        proj: W&B project path to query.\n",
    "        tune_on: column used to pick the best run (e.g. 'train_loss_avg',\n",
    "            'val/loss', 'step'); minimised if its name contains 'loss',\n",
    "            maximised otherwise. ``None`` disables every selection step.\n",
    "        avg_seeds: if True, rows sharing all grouping keys are aggregated\n",
    "            (numeric columns by mean, keys by first value, histories by\n",
    "            ``average_df_series``).\n",
    "        history: if True, download a sampled history per run into a\n",
    "            'history' column holding one DataFrame per run.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with the config keys, 'train_loss_avg', 'val/loss', 'step',\n",
    "        '_runtime', 'runtime' (``_runtime / 60``) and optionally 'history'.\n",
    "        Values missing from a run are stored as -1.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if `tune_on` is not a column of the assembled table.\n",
    "    \"\"\"\n",
    "    api = wandb.Api()\n",
    "    runs = api.runs(proj, filters=filters, order='-created_at')\n",
    "    summary_list, config_list, name_list, history_list = [], [], [], []\n",
    "    for run in runs:\n",
    "        summary = run.summary._json_dict\n",
    "        summary_list.append(summary)\n",
    "        # .config contains the hyperparameters; special keys starting with '_' are removed.\n",
    "        config_list.append({k: v for k, v in run.config.items() if not k.startswith('_')})\n",
    "        # .name is the human-readable name of the run.\n",
    "        name_list.append(run.name)\n",
    "        if history:\n",
    "            # NOTE(review): _sampled_history is a private wandb method; confirm it\n",
    "            # still exists when upgrading wandb.\n",
    "            if 'train_loss_avg' in summary:\n",
    "                hist = run._sampled_history(keys=['current_compute', 'train_loss_avg'], x_axis='step', samples=1000)\n",
    "            else:\n",
    "                hist = run._sampled_history(keys=['compute', 'val/loss'], x_axis='iter', samples=1000)\n",
    "            history_list.append(pd.DataFrame(hist))\n",
    "\n",
    "    runs_dict = {\n",
    "        'summary': summary_list,\n",
    "        'config': config_list,\n",
    "        'name': name_list,\n",
    "    }\n",
    "    if history:\n",
    "        runs_dict['history'] = history_list\n",
    "    runs_df = pd.DataFrame(runs_dict)\n",
    "\n",
    "    # Drop runs without a summary. Copy so the column assignments below write\n",
    "    # to an independent frame instead of a slice of the original.\n",
    "    runs_df = runs_df[runs_df['summary'].apply(lambda x: x != {})].copy()\n",
    "\n",
    "    # Hyperparameters that identify a configuration (also the grouping keys).\n",
    "    group_keys = ['width', 'depth', 'struct', 'layers', 'cola_params', 'cola_flops', 'lr', 'batch_size',\n",
    "                  'tt_rank', 'expr', 'd_model', 'num_ffn_experts', 'num_experts']\n",
    "    for key in group_keys:\n",
    "        runs_df[key] = runs_df['config'].apply(lambda cfg: cfg.get(key, -1))\n",
    "\n",
    "    # '_runtime' must come from the summary: keys starting with '_' were removed\n",
    "    # from the config above, so looking it up there always produced -1. It is\n",
    "    # deliberately not a grouping key (it differs for every run).\n",
    "    for key in ['train_loss_avg', 'val/loss', 'step', '_runtime']:\n",
    "        runs_df[key] = runs_df['summary'].apply(lambda s: s.get(key, -1))\n",
    "    runs_df['runtime'] = runs_df['_runtime'] / 60  # convert to minutes\n",
    "\n",
    "    runs_df = runs_df.drop(columns=['summary', 'config', 'name'])\n",
    "\n",
    "    if tune_on is not None and tune_on not in runs_df.columns:\n",
    "        raise KeyError(f'tune_on={tune_on!r} is not a column of the runs table; available: {list(runs_df.columns)}')\n",
    "\n",
    "    # Everything else being equal, only keep the best run\n",
    "    if tune_on is not None:\n",
    "        runs_df = _select_best(runs_df, group_keys, tune_on)\n",
    "\n",
    "    # average over seeds\n",
    "    if avg_seeds:\n",
    "        # NOTE(review): 'seed' is not one of the grouping keys, so when tune_on is\n",
    "        # set the selection above already left one run per configuration and this\n",
    "        # mean is effectively a no-op; confirm that this is intended.\n",
    "        numeric_cols = runs_df.select_dtypes(include=[np.number]).columns.tolist()\n",
    "        agg_dict = {col: 'mean' for col in numeric_cols}\n",
    "        agg_dict.update({col: 'first' for col in group_keys})\n",
    "        if history:\n",
    "            agg_dict['history'] = average_df_series\n",
    "        runs_df = runs_df.groupby(group_keys).agg(agg_dict).reset_index(drop=True)\n",
    "\n",
    "    # Tune LR: among rows that differ only in learning rate keep the best one.\n",
    "    if tune_on is not None:\n",
    "        lr_free_keys = [k for k in group_keys if k != 'lr']\n",
    "        runs_df = _select_best(runs_df, lr_free_keys, tune_on)\n",
    "    return runs_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_2d_vec(row):\n",
    "    \"\"\"Parse row['expr'] (hyphen-separated numbers wrapped in one character on each side) into a reordered 7-tuple of floats.\"\"\"\n",
    "    values = list(map(float, row['expr'][1:-1].split('-')))\n",
    "    # The permutation applied depends on whether the leading entry is positive.\n",
    "    order = [0, 1, 2, 6, 3, 4, 5] if values[0] > 0 else [1, 0, 2, 3, 5, 4, 6]\n",
    "    return tuple(values[i] for i in order)\n",
    "\n",
    "def cleanup_expr(row):\n",
    "    \"\"\"Build the LaTeX plot label for a run from row['expr'] and row['struct'].\n",
    "\n",
    "    The coefficients of row['expr'] are reordered (same permutation as\n",
    "    get_2d_vec), the values 0.25, 0.33, 0.5, 0.67 and 0.75 are rendered as\n",
    "    LaTeX fractions, the list is wrapped in '$(...)$', and '-RMS' / '-N' is\n",
    "    appended when row['struct'] contains 'rms_norm' / 'norm'.\n",
    "    \"\"\"\n",
    "    coeffs = row['expr'][1:-1].replace('-', ',').split(',')\n",
    "    if float(coeffs[0]) > 0:\n",
    "        order = [0, 1, 2, 6, 3, 4, 5]\n",
    "    else:\n",
    "        order = [1, 0, 2, 3, 5, 4, 6]\n",
    "    fraction_map = {0.25: '\\\\frac{1}{4}', 0.33: '\\\\frac{1}{3}', 0.5: '\\\\frac{1}{2}', 0.67: '\\\\frac{2}{3}', 0.75: '\\\\frac{3}{4}'}\n",
    "\n",
    "    def as_latex(token):\n",
    "        # Substitute whole coefficients only: a substring replace on the joined\n",
    "        # string would corrupt values that merely contain a key (e.g. '0.55').\n",
    "        try:\n",
    "            return fraction_map.get(float(token), token)\n",
    "        except ValueError:\n",
    "            return token\n",
    "\n",
    "    expr = '$(' + ','.join(as_latex(coeffs[i]) for i in order) + ')$'\n",
    "    if 'rms_norm' in row['struct']:\n",
    "        expr = expr + '-RMS'\n",
    "    elif 'norm' in row['struct']:\n",
    "        expr = expr + '-N'\n",
    "    return expr\n",
    "\n",
    "def get_omega(row):\n",
    "    \"\"\"omega = min(a + d, b + e) - min(a, e), with row['vec'] laid out as (a, b, c, r, d, e, f).\"\"\"\n",
    "    vec = row['vec']\n",
    "    a, b, d, e = vec[0], vec[1], vec[4], vec[5]\n",
    "    return min(a + d, b + e) - min(a, e)\n",
    "\n",
    "def get_sigma(row):\n",
    "    \"\"\"sigma = min(r + d + f - a, b + c + r - e), with row['vec'] laid out as (a, b, c, r, d, e, f).\"\"\"\n",
    "    vec = row['vec']\n",
    "    a, b, c, r, d, e, f = (vec[i] for i in range(7))\n",
    "    first = r + d + f - a\n",
    "    second = b + c + r - e\n",
    "    return min(first, second)\n",
    "\n",
    "def get_nu(row):\n",
    "    \"\"\"nu = 1 + r - min(a, e), with row['vec'] laid out as (a, b, c, r, d, e, f).\"\"\"\n",
    "    vec = row['vec']\n",
    "    r = vec[3]\n",
    "    smaller = min(vec[0], vec[5])\n",
    "    return 1 + r - smaller"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "input_dim = 8\n",
    "filters = {\n",
    "    'state': 'finished',\n",
    "    'config.layers': {'$in': ['all_but_last', 'intermediate']},\n",
    "    'config.model': {'$eq': 'MLP'},\n",
    "    'config.input_dim': input_dim,\n",
    "    'config.expr': {'$ne': '(1-0-0-0.5-0-1-0)'},\n",
    "    'config.depth': 3,\n",
    "    'config.steps': 1e6,\n",
    "    # 'config.cola_flops': {'$lt': 1e7},\n",
    "}\n",
    "plot_name = f'synth_d{input_dim}'\n",
    "\n",
    "runs = get_runs(filters, proj='ap-team/synth', tune_on='train_loss_avg', history=False, avg_seeds=True)\n",
    "# .copy() after filtering: columns are assigned below, which must write to an\n",
    "# independent frame rather than to a slice (avoids SettingWithCopyWarning).\n",
    "runs = runs[runs['struct'].isin(['dense', 'simple_ein_vec_norm', 'simple_ein_vec', 'simple_ein_vec_rms_norm'])].copy()\n",
    "ein_runs = runs[runs['struct'].isin(['simple_ein_vec', 'simple_ein_vec_norm', 'simple_ein_vec_rms_norm'])]\n",
    "ein_runs = ein_runs[ein_runs['layers'] == 'intermediate'].copy()\n",
    "ein_runs['struct'] = ein_runs.apply(cleanup_expr, axis=1)\n",
    "ein_runs['vec'] = ein_runs.apply(get_2d_vec, axis=1)\n",
    "# vecs = [(0.5, 0, 0.5, 0, 0, 0.5, 0.5), (0.5, 0, 0.5, 0.25, 0, 0.5, 0.5)][:1]\n",
    "# The trailing slice selects which of the candidate vectors are plotted.\n",
    "vecs = [(0.5, 0, 0.5, 0, 0, 0.5, 0.5), (0.67, 0, 0.33, 0, 0, 0.67, 0.33), (0.67, 0, 0.33, 0.25, 0, 0.67, 0.33)][2:]\n",
    "ein_runs = ein_runs[ein_runs['vec'].isin(vecs)]\n",
    "\n",
    "# Only the dense rows keep this placeholder vec; ein_runs carry their own.\n",
    "runs['vec'] = [(1, 0, 0, 0, 1, 0, 0)] * runs.shape[0]\n",
    "runs = pd.concat([runs[runs['struct'].isin(['dense'])], ein_runs])\n",
    "runs = runs[runs['step'] == 999000].copy()\n",
    "# Capitalize 'dense' only: str.capitalize() would also lowercase the rest of\n",
    "# every label and turn the '-RMS' / '-N' suffixes into '-rms' / '-n'.\n",
    "runs['struct'] = runs['struct'].replace({'dense': 'Dense'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import ScalarFormatter\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "xname = 'cola_flops'\n",
    "yname = 'train_loss_avg'\n",
    "ylabel = 'Loss'\n",
    "fit = input_dim in [8, 64]\n",
    "# Figures are written to ./figures; create it so a fresh checkout can run.\n",
    "os.makedirs('./figures', exist_ok=True)\n",
    "# log log scale, scatter plot\n",
    "sns.set(style='whitegrid', font_scale=2.0, rc={'lines.linewidth': 3.0})\n",
    "# unique structure\n",
    "structs = runs['struct'].unique()\n",
    "pallette = sns.color_palette('hls', n_colors=len(structs))\n",
    "\n",
    "\n",
    "plt.figure(dpi=100, figsize=(8, 6))\n",
    "ax = sns.scatterplot(data=runs, x=xname, y=yname, hue='struct', markers=True, hue_order=structs, s=100, palette=pallette, alpha=0.9, style='lr')\n",
    "# ax = sns.lineplot(data=runs, x=xname, y=yname, hue='struct', hue_order=structs, palette=pallette, linewidth=2, alpha=0.9, marker='o', markersize=10)\n",
    "if fit:\n",
    "    # fit a line on the log log scale, for each struct\n",
    "    slopes = []\n",
    "    halfwidths = []\n",
    "    for struct, color in zip(structs, pallette):\n",
    "        struct_runs = runs[runs['struct'] == struct]\n",
    "        x = np.log(struct_runs[xname])\n",
    "        y = np.log(struct_runs[yname])\n",
    "        m, b, _, _, std_err = stats.linregress(x, y)\n",
    "        # The reported exponent is the negated log-log slope.\n",
    "        slopes.append(-m)\n",
    "        # Error-bar half-width is ONE standard error of the slope, not a 95%\n",
    "        # confidence interval.\n",
    "        halfwidths.append(std_err)\n",
    "        plt.plot(struct_runs[xname], np.exp(m * x + b), color=color, linestyle='--', linewidth=2, alpha=0.75)\n",
    "plt.ylabel(ylabel)\n",
    "plt.xlabel('FLOPs' if xname == 'cola_flops' else 'Params')\n",
    "plt.xscale('log')\n",
    "plt.yscale('log')\n",
    "# plt.yticks([1e-2, 1e-3])\n",
    "\n",
    "# yticks as plain numbers (not scientific notation)\n",
    "# plt.gca().yaxis.set_major_formatter(ScalarFormatter())\n",
    "# plt.gca().yaxis.set_minor_formatter(ScalarFormatter())\n",
    "\n",
    "ax.grid(which='minor', axis='y', linestyle='-', linewidth=0.5)\n",
    "# Keep the handles so the legend can be rendered as its own figure below.\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "# remove legend\n",
    "ax.get_legend().remove()\n",
    "# # legend outside\n",
    "# plt.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=2, fontsize=20)\n",
    "# plt.tight_layout()\n",
    "plt.savefig(f'./figures/{plot_name}_{xname}.pdf', bbox_inches='tight')\n",
    "\n",
    "# legend as a separate figure\n",
    "legend_fig = plt.figure(figsize=(8, 1))  # Adjust size as needed\n",
    "ax_legend = legend_fig.add_subplot(111)\n",
    "# max(1, ...) keeps ncol valid when there are fewer than two legend entries.\n",
    "ax_legend.legend(handles, labels, loc='center', ncol=max(1, len(labels) // 2), fontsize=15)\n",
    "ax_legend.axis('off')  # Hide axes\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'./figures/{plot_name}_legend.pdf', bbox_inches='tight')\n",
    "# barplot for slopes\n",
    "if fit:\n",
    "    plt.figure(dpi=100, figsize=(8, 6))\n",
    "    # plt.figure(dpi=10, figsize=(4, 6))\n",
    "    ax = sns.barplot(x=structs, y=slopes, palette=pallette)\n",
    "    # Centre each error bar on its bar.\n",
    "    x_coords = [p.get_x() + 0.5 * p.get_width() for p in ax.patches]\n",
    "    y_coords = [p.get_height() for p in ax.patches]\n",
    "    ax.errorbar(x=x_coords, y=y_coords, yerr=halfwidths, fmt='none', c='k')\n",
    "    # horizontal reference line at the first struct's exponent\n",
    "    # (dense rows are concatenated first when `runs` is built)\n",
    "    dense_slope = slopes[0]\n",
    "    plt.axhline(y=dense_slope, color='black', linestyle='--', linewidth=2)\n",
    "    plt.ylabel(r'Exponent $p$')\n",
    "    # rotate x labels\n",
    "    plt.xticks(rotation=70, fontsize=12, ha='center')\n",
    "    # remove xticks\n",
    "    # plt.xticks([], [])\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'./figures/{plot_name}_{xname}_slopes.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_exponent_plane(data, y_col, y_label, out_path):\n",
    "    \"\"\"Scatter `data` on the (omega, y_col) plane, coloured by 'slope', and save it to `out_path`.\"\"\"\n",
    "    cmap = sns.color_palette('viridis', as_cmap=True)\n",
    "    plt.figure(dpi=100, figsize=(8, 6))\n",
    "    ax = sns.scatterplot(data=data, x='omega', y=y_col, hue='slope', s=100, alpha=0.75, palette=cmap, legend=False)\n",
    "    plt.xlabel(r'$\\omega$')\n",
    "    plt.ylabel(y_label)\n",
    "\n",
    "    # Colorbar over the observed slope range, ticked at its two ends only.\n",
    "    lo, hi = data['slope'].min(), data['slope'].max()\n",
    "    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(lo, hi))\n",
    "    sm.set_array([])\n",
    "    cbar = plt.colorbar(sm, ax=ax)\n",
    "    cbar.set_label('$p$', labelpad=10)\n",
    "    cbar.set_ticks([lo, hi])\n",
    "    cbar.set_ticklabels([f'{lo:.2f}', f'{hi:.2f}'])\n",
    "    # Thin black frame around the colorbar.\n",
    "    cbar.outline.set_linewidth(0.5)\n",
    "    cbar.outline.set_edgecolor('black')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(out_path, bbox_inches='tight')\n",
    "\n",
    "\n",
    "df = pd.DataFrame({'struct': structs, 'slope': slopes, 'halfwidth': halfwidths})\n",
    "# Look the vec up per struct. Taking runs['vec'].unique() positionally breaks as\n",
    "# soon as two structs share a vec (e.g. the '-N' and '-RMS' variants of one expr).\n",
    "vec_by_struct = runs.drop_duplicates(subset='struct').set_index('struct')['vec']\n",
    "df['vec'] = df['struct'].map(vec_by_struct)\n",
    "df['omega'] = df.apply(get_omega, axis=1)\n",
    "df['sigma'] = df.apply(get_sigma, axis=1)\n",
    "df['nu'] = df.apply(get_nu, axis=1)\n",
    "\n",
    "# Plot slope on the (omega, sigma) plane, then on the (omega, nu) plane.\n",
    "_plot_exponent_plane(df, 'sigma', r'$\\sigma$', f'./figures/{plot_name}_omega_sigma.pdf')\n",
    "_plot_exponent_plane(df, 'nu', r'$\\nu$', f'./figures/{plot_name}_omega_nu.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import linprog, curve_fit\n",
    "\n",
    "def power_law_with_step(x, c_flops, p_flops, c_step, p_step):\n",
    "    \"\"\"Additive power law c_flops * F**(-p_flops) + c_step * T**(-p_step), where x unpacks to (F, T).\"\"\"\n",
    "    flops, steps = x\n",
    "    flops_term = c_flops * flops**(-p_flops)\n",
    "    step_term = c_step * steps**(-p_step)\n",
    "    return flops_term + step_term\n",
    "\n",
    "def fit_power_law_with_step(df, min_step=0, max_step=1e32, flops_scale=1e5, step_scale=1e6, loss_scale=1e-2):\n",
    "    \"\"\"Fit loss = c_F * cola_flops**(-p_F) + c_T * step**(-p_T) with scipy's curve_fit.\n",
    "\n",
    "    Only rows with min_step <= step <= max_step are used. The data is divided by\n",
    "    the given scales before fitting and the two coefficients are converted back\n",
    "    afterwards. Returns (c_F, p_F, c_T, p_T).\n",
    "    \"\"\"\n",
    "    in_range = (df['step'] >= min_step) & (df['step'] <= max_step)\n",
    "    scaled_flops = df.loc[in_range, 'cola_flops'] / flops_scale\n",
    "    scaled_step = df.loc[in_range, 'step'] / step_scale\n",
    "    scaled_loss = df.loc[in_range, 'loss'] / loss_scale\n",
    "\n",
    "    popt, _ = curve_fit(power_law_with_step, np.vstack((scaled_flops, scaled_step)), scaled_loss)\n",
    "    c_flops_fit, p_flops_fit, c_step_fit, p_step_fit = popt\n",
    "\n",
    "    # Undo the rescaling; the exponents are unaffected by it.\n",
    "    c_flops_fit = c_flops_fit * (loss_scale * flops_scale**(p_flops_fit))\n",
    "    c_step_fit = c_step_fit * (loss_scale * step_scale**(p_step_fit))\n",
    "\n",
    "    return c_flops_fit, p_flops_fit, c_step_fit, p_step_fit\n",
    "\n",
    "def fit_power_law(df, min_compute=0, max_compute=1e32):\n",
    "    \"\"\"Fit a line log(loss) = b - a * log(compute) lying on or below every point, via linear programming.\n",
    "\n",
    "    Only rows with min_compute <= compute <= max_compute are used.\n",
    "    Returns (a, b) as found by scipy's linprog.\n",
    "    \"\"\"\n",
    "    keep = (df['compute'] >= min_compute) & (df['compute'] <= max_compute)\n",
    "    log_c = np.log(df['compute'])[keep]\n",
    "    log_l = np.log(df['loss'])[keep]\n",
    "    n_points = len(log_c)\n",
    "\n",
    "    # linprog minimises, so the objective is the negated sum over points of\n",
    "    # (b - a * log_c); each constraint row enforces b - a * log_c <= log_l.\n",
    "    objective = -np.array([np.sum(-log_c), n_points])\n",
    "    constraint_rows = np.column_stack((-log_c, np.ones(n_points)))\n",
    "\n",
    "    # bounds=(None, None) leaves both variables unbounded.\n",
    "    solution = linprog(objective, A_ub=constraint_rows, b_ub=log_l, method='highs', bounds=(None, None))\n",
    "    slope, intercept = solution.x\n",
    "    return slope, intercept\n",
    "\n",
    "def _history_triples(history):\n",
    "    \"\"\"Return an iterator of (step, compute, loss) triples for one run's history.\n",
    "\n",
    "    get_runs stores each history as a DataFrame whose columns are either\n",
    "    ('step', 'current_compute', 'train_loss_avg') or ('iter', 'compute', 'val/loss').\n",
    "    Iterating a DataFrame yields its column names, not its rows, so the columns\n",
    "    are selected by name here. Any other iterable is assumed to already contain\n",
    "    (step, compute, loss) triples.\n",
    "    \"\"\"\n",
    "    if not isinstance(history, pd.DataFrame):\n",
    "        return iter(history)\n",
    "    for columns in (['step', 'current_compute', 'train_loss_avg'], ['iter', 'compute', 'val/loss']):\n",
    "        if set(columns) <= set(history.columns):\n",
    "            # Rows with a missing value cannot be used by the log-space fits.\n",
    "            return history[columns].dropna().itertuples(index=False, name=None)\n",
    "    raise KeyError(f'history has unexpected columns: {list(history.columns)}')\n",
    "\n",
    "\n",
    "def get_history_df(df, min_step=0, max_step=1e6, min_compute=0, max_compute=1e32):\n",
    "    \"\"\"Flatten per-run histories into one long table and fit both power laws to it.\n",
    "\n",
    "    Args:\n",
    "        df: runs table with the columns in `run_keys` plus a 'history' column.\n",
    "        min_step, max_step: inclusive step window; points outside are dropped.\n",
    "        min_compute, max_compute: compute window forwarded to fit_power_law.\n",
    "\n",
    "    Returns:\n",
    "        (hist_df, slope, intercept, c_flops, p_flops, c_step, p_step), where\n",
    "        slope/intercept come from fit_power_law and the last four values from\n",
    "        fit_power_law_with_step.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no history point falls inside the step window.\n",
    "        KeyError: if a history DataFrame has none of the expected column sets.\n",
    "    \"\"\"\n",
    "    run_keys = ['cola_flops', 'cola_params', 'depth', 'width', 'struct', 'lr']\n",
    "    hist_points = []\n",
    "\n",
    "    # One output row per (run, logged step).\n",
    "    for _, row in df.iterrows():\n",
    "        run_info = {k: row[k] for k in run_keys}\n",
    "        # Constant per run, so computed once instead of once per step.\n",
    "        log_cola_flops = np.log(run_info['cola_flops'])\n",
    "        for step, compute, loss in _history_triples(row['history']):\n",
    "            if step < min_step or step > max_step:\n",
    "                continue\n",
    "            hist_points.append({**run_info, 'step': step, 'compute': compute, 'loss': loss,\n",
    "                                'log_cola_flops': log_cola_flops})\n",
    "\n",
    "    if not hist_points:\n",
    "        raise ValueError(f'no history points with step in [{min_step}, {max_step}]')\n",
    "    hist_df = pd.DataFrame(hist_points)\n",
    "\n",
    "    # fit power law\n",
    "    slope, intercept = fit_power_law(hist_df, min_compute, max_compute)\n",
    "\n",
    "    # fit power law with step\n",
    "    c_flops_opt, p_flops_opt, c_step_opt, p_step_opt = fit_power_law_with_step(hist_df, min_step, max_step)\n",
    "\n",
    "    return hist_df, slope, intercept, c_flops_opt, p_flops_opt, c_step_opt, p_step_opt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = 8\n",
    "filters = {\n",
    "    'config.layers': {'$in': ['all_but_last', 'intermediate']},\n",
    "    'config.model': {'$eq': 'MLP'},\n",
    "    'config.input_dim': input_dim,\n",
    "    'config.depth': 3,\n",
    "    'config.steps': 1e6,\n",
    "    'config.struct': {'$in': ['dense', 'simple_ein_vec_norm']},\n",
    "    'config.lr': 3e-3,\n",
    "    # 'config.cola_flops': {'$lt': 1e7},\n",
    "}\n",
    "plot_name = f'synth_d{input_dim}'\n",
    "\n",
    "runs = get_runs(filters, proj='ap-team/synth', tune_on='train_loss_avg', history=True, avg_seeds=True)\n",
    "# .copy() after filtering: columns are assigned below, which must write to an\n",
    "# independent frame rather than to a slice (avoids SettingWithCopyWarning).\n",
    "runs = runs[runs['struct'].isin(['dense', 'simple_ein_vec_norm'])].copy()\n",
    "ein_runs = runs[runs['struct'].isin(['simple_ein_vec_norm'])]\n",
    "ein_runs = ein_runs[ein_runs['layers'] == 'intermediate'].copy()\n",
    "ein_runs['vec'] = ein_runs.apply(get_2d_vec, axis=1)\n",
    "ein_runs['struct'] = ein_runs.apply(cleanup_expr, axis=1)\n",
    "# Only the dense rows keep this placeholder vec; ein_runs carry their own.\n",
    "runs['vec'] = [(1, 0, 0, 0, 1, 0, 0)] * runs.shape[0]\n",
    "runs = pd.concat([runs[runs['struct'] == 'dense'], ein_runs])\n",
    "# runs = runs[runs['step'] == 999000]\n",
    "# Capitalize 'dense' only: str.capitalize() would also lowercase the rest of\n",
    "# every label and turn the '-RMS' / '-N' suffixes into '-rms' / '-n'.\n",
    "runs['struct'] = runs['struct'].replace({'dense': 'Dense'})\n",
    "runs['omega'] = runs.apply(get_omega, axis=1)\n",
    "runs['sigma'] = runs.apply(get_sigma, axis=1)\n",
    "runs['nu'] = runs.apply(get_nu, axis=1)\n",
    "dense_runs = runs[runs['struct'] == 'Dense']\n",
    "ein_runs = runs[runs['struct'] != 'Dense']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set(style='whitegrid', font_scale=2.0, rc={'lines.linewidth': 3.0})\n",
    "fit = input_dim == 8\n",
    "MIN_STEP = 3e3\n",
    "MAX_STEP = 1e5\n",
    "MIN_COMPUTE = 1e12\n",
    "MAX_COMPUTE = 1e16  # 3e15\n",
    "\n",
    "# set palette\n",
    "sns.set_palette('Set1')\n",
    "dense_color = 'C1'\n",
    "struct_color = 'C0'\n",
    "\n",
    "\n",
    "def _plot_step_fit(hist_df, fit_params, color, label, title):\n",
    "    \"\"\"Plot loss vs. step per width with the fitted step power law and save it.\n",
    "\n",
    "    `fit_params` is (c_flops, p_flops, c_step, p_step). The figure is written to\n",
    "    './figures/{plot_name}_{title}_step.pdf'.\n",
    "    \"\"\"\n",
    "    c_flops, p_flops, c_step, p_step = fit_params\n",
    "    plt.figure(dpi=100, figsize=(8, 6))\n",
    "    # Keep the Axes returned here so the grid call below targets THIS figure,\n",
    "    # not an Axes object left over from an earlier cell.\n",
    "    ax = sns.lineplot(data=hist_df, x='step', y='loss', hue='width', markers=True, alpha=0.75, palette=[color], legend=False)\n",
    "    # plot power law with step fit for each width\n",
    "    step_range = np.linspace(MIN_STEP, MAX_STEP, 100)\n",
    "    for width in hist_df['width'].unique():\n",
    "        cola_flops = hist_df.loc[hist_df['width'] == width, 'cola_flops'].iloc[0]\n",
    "        y_fit = c_flops * cola_flops**(-p_flops) + c_step * step_range**(-p_step)\n",
    "        plt.plot(step_range, y_fit, linestyle='--', linewidth=2, alpha=0.75, color=color)\n",
    "    # Empty line that only provides the legend entry.\n",
    "    plt.plot([], [], color=color, linestyle='-', linewidth=2, alpha=1, label=label)\n",
    "    plt.ylim(2e-3, 1e-1)\n",
    "    plt.ylabel('Loss')\n",
    "    plt.xlabel('Step')\n",
    "    plt.yscale('log')\n",
    "    plt.xscale('log')\n",
    "    ax.grid(which='minor', axis='y', linestyle='-', linewidth=0.5)\n",
    "    plt.title(f'{title}')\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'./figures/{plot_name}_{title}_step.pdf', bbox_inches='tight')\n",
    "\n",
    "\n",
    "dense_hist_df, dense_slope, dense_intercept, dense_c_flops_opt, dense_p_flops_opt, dense_c_step_opt, dense_p_step_opt = get_history_df(dense_runs, min_step=MIN_STEP, max_step=MAX_STEP, min_compute=MIN_COMPUTE, max_compute=MAX_COMPUTE)\n",
    "_plot_step_fit(dense_hist_df, (dense_c_flops_opt, dense_p_flops_opt, dense_c_step_opt, dense_p_step_opt), dense_color, 'Dense', 'dense')\n",
    "\n",
    "\n",
    "# Per-struct results; the dense baseline is always the first entry.\n",
    "structs = ['Dense']\n",
    "compute_opt_expos = [dense_slope]\n",
    "model_flops_expos = [dense_p_flops_opt]\n",
    "step_expos = [dense_p_step_opt]\n",
    "vecs = [(1, 0, 0, 0, 1, 0, 0)]\n",
    "\n",
    "for struct in ein_runs['struct'].unique():\n",
    "    struct_runs = ein_runs[ein_runs['struct'] == struct]\n",
    "    vec = struct_runs['vec'].iloc[0]\n",
    "\n",
    "    # Skip vecs with a zero leading entry and structs with fewer than 3 runs.\n",
    "    # Checked before opening any figure so skipped structs leave no empty one.\n",
    "    if vec[0] == 0 or struct_runs.shape[0] < 3:\n",
    "        continue\n",
    "\n",
    "    # Kron should be fitted in the smaller range\n",
    "    max_compute = 1e14 if vec == (1/2, 1/2, 0, 0, 1/2, 1/2, 0) else MAX_COMPUTE\n",
    "    hist_df, slope, intercept, c_flops_opt, p_flops_opt, c_step_opt, p_step_opt = get_history_df(struct_runs, min_step=MIN_STEP, max_step=MAX_STEP, min_compute=MIN_COMPUTE, max_compute=max_compute)\n",
    "\n",
    "    compute_opt_expos.append(slope)\n",
    "    model_flops_expos.append(p_flops_opt)\n",
    "    step_expos.append(p_step_opt)\n",
    "    structs.append(struct)\n",
    "    vecs.append(vec)\n",
    "\n",
    "    _plot_step_fit(hist_df, (c_flops_opt, p_flops_opt, c_step_opt, p_step_opt), struct_color, 'Struct', struct)\n",
    "\n",
    "    plt.figure(dpi=100, figsize=(8, 6))\n",
    "\n",
    "    # plot data\n",
    "    ax = sns.lineplot(data=hist_df, x='compute', y='loss', hue='width', markers=True, alpha=0.5, palette=[struct_color], linewidth=1.5, legend=False, linestyle='-')\n",
    "    sns.lineplot(data=dense_hist_df, x='compute', y='loss', hue='width', markers=True, alpha=0.5, palette=[dense_color], linewidth=1.5, legend=False, linestyle='-')\n",
    "    ax.grid(which='minor', axis='y', linestyle='-', linewidth=0.5)\n",
    "\n",
    "    # plot fits\n",
    "    if fit:\n",
    "        x = np.linspace(MIN_COMPUTE, MAX_COMPUTE, 100)\n",
    "        y = np.exp(-dense_slope * np.log(x) + dense_intercept)\n",
    "        plt.plot(x, y, color=dense_color, linestyle='--', linewidth=2, alpha=1)\n",
    "        y = np.exp(-slope * np.log(x) + intercept)\n",
    "        plt.plot(x, y, color=struct_color, linestyle='--', linewidth=2, alpha=1)\n",
    "\n",
    "    # manually add legend\n",
    "    plt.plot([], [], color=dense_color, linestyle='-', linewidth=2, alpha=1, label='Dense')\n",
    "    plt.plot([], [], color=struct_color, linestyle='-', linewidth=2, alpha=1, label='Struct')\n",
    "\n",
    "    plt.ylabel('Loss')\n",
    "    plt.xlabel('Compute')\n",
    "    plt.yscale('log')\n",
    "    plt.xscale('log')\n",
    "    plt.legend()\n",
    "    omega, nu, sigma = (struct_runs[col].iloc[0] for col in ('omega', 'nu', 'sigma'))\n",
    "    plt.title(rf'{struct}, $\\omega={omega:.2f}, \\nu={nu:.2f}, \\sigma={sigma:.2f}$', fontsize=20)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'./figures/{plot_name}_{vec}_compute.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame({'struct': structs, 'vec': vecs, 'p_compute': compute_opt_expos, 'p_flops': model_flops_expos, 'p_step': step_expos})\n",
    "# fits not good\n",
    "bad_fit_vecs = {\n",
    "    (0.5, 0.5, 0, 0.5, 0.5, 0.5, 0),\n",
    "    (1, 0, 0, 0.5, 0, 1, 0),\n",
    "    (1, 0, 0, 0.75, 0, 1, 0),\n",
    "}\n",
    "# Explicit per-row membership test: comparing a Series of tuples with `!=`\n",
    "# against a tuple depends on how pandas interprets the tuple (scalar vs. array).\n",
    "df = df[df['vec'].map(lambda v: v not in bad_fit_vecs)].copy()\n",
    "\n",
    "df['p_compute_theory'] = (df['p_flops'] * df['p_step']) / (df['p_flops'] + df['p_step'])\n",
    "\n",
    "pallette = sns.color_palette('hls', n_colors=len(structs))\n",
    "\n",
    "# One bar chart per exponent: (column, y-axis label, output file suffix).\n",
    "bar_specs = [\n",
    "    ('p_compute', r'$p_C$', 'compute_slopes'),\n",
    "    ('p_compute_theory', r'$\\hat{p}_C$', 'compute_theory_slopes'),\n",
    "    ('p_flops', r'$p_F$', 'flops_slopes'),\n",
    "    ('p_step', r'$p_T$', 'step_slopes'),\n",
    "]\n",
    "for column, label_tex, suffix in bar_specs:\n",
    "    plt.figure(dpi=100, figsize=(8, 6))\n",
    "    ax = sns.barplot(x='struct', y=column, data=df, palette=pallette)\n",
    "    # Dashed reference line at the first row's value.\n",
    "    plt.axhline(y=df[column].iloc[0], color='black', linestyle='--', linewidth=2)\n",
    "    plt.ylabel(label_tex)\n",
    "    # rotate x labels\n",
    "    plt.xticks(rotation=70, fontsize=12, ha='center')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'./figures/{plot_name}_{suffix}.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_compute_exponent_plane(data, y_col, y_label, out_path):\n",
    "    \"\"\"Scatter `data` on the (omega, y_col) plane, coloured by 'slope', and save it to `out_path`.\"\"\"\n",
    "    cmap = sns.color_palette('viridis', as_cmap=True)\n",
    "    plt.figure(dpi=100, figsize=(8, 6))\n",
    "    ax = sns.scatterplot(data=data, x='omega', y=y_col, hue='slope', s=100, alpha=0.75, palette=cmap, legend=False)\n",
    "    plt.xlabel(r'$\\omega$')\n",
    "    plt.ylabel(y_label)\n",
    "\n",
    "    # Colorbar over the observed slope range, ticked at its two ends only.\n",
    "    lo, hi = data['slope'].min(), data['slope'].max()\n",
    "    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(lo, hi))\n",
    "    sm.set_array([])\n",
    "    cbar = plt.colorbar(sm, ax=ax)\n",
    "    cbar.set_label('$p$', labelpad=10)\n",
    "    cbar.set_ticks([lo, hi])\n",
    "    cbar.set_ticklabels([f'{lo:.2f}', f'{hi:.2f}'])\n",
    "    # Thin black frame around the colorbar.\n",
    "    cbar.outline.set_linewidth(0.5)\n",
    "    cbar.outline.set_edgecolor('black')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(out_path, bbox_inches='tight')\n",
    "\n",
    "\n",
    "# `exponents` is not defined anywhere in this notebook (NameError on a fresh\n",
    "# kernel); the compute exponents collected next to `structs` and `vecs` are used.\n",
    "df = pd.DataFrame({'struct': structs, 'slope': compute_opt_expos, 'vec': vecs})\n",
    "# get omega, sigma, nu from vec\n",
    "df['omega'] = df.apply(get_omega, axis=1)\n",
    "df['sigma'] = df.apply(get_sigma, axis=1)\n",
    "df['nu'] = df.apply(get_nu, axis=1)\n",
    "# negate the slopes\n",
    "# NOTE(review): fit_power_law is plotted elsewhere as exp(-slope * log(x) + b),\n",
    "# i.e. its slope already looks like a positive exponent; confirm this sign flip\n",
    "# is still wanted.\n",
    "df['slope'] = -df['slope']\n",
    "\n",
    "# Plot slope on the (omega, sigma) plane, then on the (omega, nu) plane.\n",
    "_plot_compute_exponent_plane(df, 'sigma', r'$\\sigma$', f'./figures/{plot_name}_compute_omega_sigma.pdf')\n",
    "_plot_compute_exponent_plane(df, 'nu', r'$\\nu$', f'./figures/{plot_name}_compute_omega_nu.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filters = {\n",
    "    \"state\": \"finished\",\n",
    "    # \"config.depth\": 3,\n",
    "    # \"config.cola_flops\": {\"$lt\": 1e7},\n",
    "}\n",
    "plot_name = 'moe_gpt'\n",
    "\n",
    "runs = get_runs(filters, proj='ap-team/moe_gpt', tune_on='val/loss', history=True, avg_seeds=True)\n",
    "runs = runs[runs['struct'].isin(['dense', 'dense_moe', 'btt_norm', 'btt_norm_moe'])]\n",
    "# runs = runs[runs['step'] == 999000]\n",
    "dense_runs = runs[runs['struct'] == 'dense']\n",
    "# Take an explicit copy: the 'struct' column is rewritten below, and assigning into a\n",
    "# boolean-filtered slice of `runs` raises pandas' SettingWithCopyWarning.\n",
    "ein_runs = runs[runs['struct'] != 'dense'].copy()\n",
    "# Keep the combined frame ordered with the dense rows first\n",
    "runs = pd.concat([dense_runs, ein_runs])\n",
    "\n",
    "# if struct == 'btt_norm', append tt_rank to the name\n",
    "ein_runs['struct'] = ein_runs.apply(lambda x: x['struct'] + f'_{x[\"tt_rank\"]}' if x['struct'] == 'btt_norm' else x['struct'], axis=1)\n",
    "\n",
    "\n",
    "def name_map(row):\n",
    "    \"\"\"Return the display label for a run row, chosen from its 'struct' value.\n",
    "\n",
    "    Mapping:\n",
    "      'dense'        -> 'Dense-<num_ffn_experts>FFN'\n",
    "      'dense_moe'    -> 'Dense-<num_experts>'\n",
    "      'btt_norm_moe' -> 'BTT-MoE'\n",
    "      'btt_norm_<r>' -> 'BTT Rank-<r>'\n",
    "    Any other value is returned unchanged.\n",
    "\n",
    "    The expert-count columns are only read in the branches that format them.\n",
    "    \"\"\"\n",
    "    name = row['struct']\n",
    "    if name == 'dense':\n",
    "        num_ffn_experts = row['num_ffn_experts']\n",
    "        return f'Dense-{num_ffn_experts}FFN'\n",
    "    if name == 'dense_moe':\n",
    "        num_experts = row['num_experts']\n",
    "        return f'Dense-{num_experts}'\n",
    "    # The exact MoE name must be tested before the generic 'btt_norm' prefix\n",
    "    if name == 'btt_norm_moe':\n",
    "        return 'BTT-MoE'\n",
    "    if name.startswith('btt_norm'):\n",
    "        return name.replace('btt_norm_', 'BTT Rank-')\n",
    "    return name\n",
    "\n",
    "\n",
    "ein_runs['struct'] = ein_runs.apply(name_map, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\", font_scale=2.0, rc={\"lines.linewidth\": 3.0})\n",
    "MIN_STEP = 0\n",
    "MAX_STEP = 3e5\n",
    "MIN_COMPUTE = 1e14\n",
    "MAX_COMPUTE = 3e17 # 3e15 \n",
    "dense_hist_df, dense_slope, dense_intercept = get_history_df(dense_runs, min_step=MIN_STEP, max_step=MAX_STEP, min_compute=MIN_COMPUTE, max_compute=MAX_COMPUTE)\n",
    "\n",
    "# set palette\n",
    "sns.set_palette(\"Set1\")\n",
    "dense_palette = ['C1']\n",
    "struct_pallette = ['C0']\n",
    "\n",
    "# The compute grid and the dense fit do not depend on the structure, so evaluate them once.\n",
    "# Each fit is drawn as loss = exp(slope * log(compute) + intercept).\n",
    "x = np.linspace(MIN_COMPUTE, MAX_COMPUTE, 100)\n",
    "dense_fit = np.exp(dense_slope * np.log(x) + dense_intercept)\n",
    "\n",
    "for struct in ein_runs['struct'].unique():\n",
    "    struct_runs = ein_runs[ein_runs['struct'] == struct]\n",
    "    hist_df, slope, intercept = get_history_df(struct_runs, min_step=MIN_STEP, max_step=MAX_STEP, min_compute=MIN_COMPUTE, max_compute=MAX_COMPUTE)\n",
    "    # Ratio of the fitted prefactors raised to |1 / slope|.\n",
    "    # NOTE(review): presumably the compute multiplier of this structure relative to dense\n",
    "    # at equal loss, assuming comparable slopes -- confirm.\n",
    "    print(struct, (np.exp(intercept) / np.exp(dense_intercept)) ** np.abs(1 / slope))\n",
    "\n",
    "    # Exactly one figure per structure (a second, never-drawn figure used to be opened here)\n",
    "    plt.figure(dpi=100, figsize=(8, 6))\n",
    "    # plot data\n",
    "    ax = sns.lineplot(data=hist_df, x='compute', y='loss', hue='cola_flops', markers=True, alpha=0.5, palette=struct_pallette, linewidth=1.5, legend=False, linestyle='-')\n",
    "    sns.lineplot(data=dense_hist_df, x='compute', y='loss', hue='cola_flops', markers=True, alpha=0.5, palette=dense_palette, linewidth=1.5, legend=False, linestyle='-')\n",
    "    ax.grid(which='minor', axis='y', linestyle='-', linewidth=0.5)\n",
    "    # plot fits\n",
    "    plt.plot(x, dense_fit, color=dense_palette[0], linestyle='--', linewidth=2, alpha=1, label='Dense')\n",
    "    struct_fit = np.exp(slope * np.log(x) + intercept)\n",
    "    plt.plot(x, struct_fit, color=struct_pallette[0], linestyle='--', linewidth=2, alpha=1, label='Struct')\n",
    "    # plt.ylim(2e-3, 1e-1)\n",
    "    plt.ylabel('Loss')\n",
    "    plt.xlabel('Compute')\n",
    "    plt.yscale('log')\n",
    "    plt.xscale('log')\n",
    "    plt.legend()\n",
    "    plt.title(f'{struct}')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'./figures/{plot_name}_{struct}_compute.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.13 ('struct')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "7b092aa812af2b1863eaf59b6fb9ad19cc27e117e4a3131fca0299695416e020"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
