{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2: imported modules are\n",
    "# reloaded before each cell runs, so edits to them apply without a restart.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from configs import *\n",
    "import pandas as pd\n",
    "from data import *\n",
    "from analysis import *\n",
    "from paper_figures import *\n",
    "from paper_tables import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide switch that the figure cells below pass as their `save` argument.\n",
    "save_figs = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): unpickling can execute arbitrary code - only load this file from a trusted source.\n",
    "df = pd.read_pickle(\n",
    "    'data/experiment_results.pickle.xz',\n",
    "    compression='xz',\n",
    ")\n",
    "# Hand a copy of the freshly loaded frame to the project's processing step.\n",
    "df = process_big_df(df.copy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# takes about 2 minutes\n",
    "# Rows tagged hparams=='sweep' are processed separately from the main analyses.\n",
    "df_sweep = process_sweep_df(\n",
    "    df.query(\"hparams=='sweep'\")\n",
    ")\n",
    "\n",
    "# One summary frame per set of experiment configurations.\n",
    "summary_df = perform_main_analysis(df, FIGURE1_CONFIGS)\n",
    "summary_df_att = perform_main_analysis(df, ATTENTION_ACCOUNTING_CONFIGS)\n",
    "summary_df_kaplan_tuned_hparams = perform_main_analysis(\n",
    "    df,\n",
    "    [('rw', 'tuned', 'long', 'const', 'kaplan', 'train')],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the main analysis again with the FIGURE1_CONFIGS_OWT2 configurations.\n",
    "summary_df_owt2 = perform_main_analysis(\n",
    "    df,\n",
    "    FIGURE1_CONFIGS_OWT2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Respect the notebook-wide `save_figs` switch instead of always writing to disk.\n",
    "figure1(summary_df, save=save_figs)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Warmup evidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the warmup-evidence figure from the main summary frame.\n",
    "warm_evidence_figure(\n",
    "    summary_df,\n",
    "    save=save_figs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IsoFLOP curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# configs_to_show is passed explicitly as None here.\n",
    "isoflop_loss_figure(\n",
    "    summary_df,\n",
    "    save=save_figs,\n",
    "    configs_to_show=None,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Different datasets and FLOP counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full results for the FIGURE1_CONFIGS_OWT2 configurations.\n",
    "full_results_figure(\n",
    "    summary_df_owt2,\n",
    "    save=save_figs,\n",
    "    configs_to_show=FIGURE1_CONFIGS_OWT2,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The original passed save='rw' unconditionally; a non-empty string is always\n",
    "# truthy, so this cell bypassed the `save_figs` switch. Gate it on the switch\n",
    "# while still forwarding 'rw' whenever saving is enabled.\n",
    "# NOTE(review): confirm how full_results_figure interprets a string `save` value.\n",
    "full_results_figure(summary_df, save=save_figs and 'rw')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Respect the notebook-wide `save_figs` switch instead of always writing to disk.\n",
    "opt_N_with_attention_figure(summary_df_att, save=save_figs)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show only the single ('rw', ..., 'kaplan', 'train') configuration, with kaplan_adjusted enabled.\n",
    "full_results_figure(\n",
    "    summary_df_kaplan_tuned_hparams,\n",
    "    configs_to_show=[('rw', 'tuned', 'long', 'const', 'kaplan', 'train')],\n",
    "    save=save_figs,\n",
    "    kaplan_adjusted=True,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Accuracy vs. compute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# takes about a minute\n",
    "# NOTE(review): the two numbers below are presumably compute (FLOP) values - confirm.\n",
    "config_compute = ('rw', 'tuned', 'short', 'const', 'standard', 'val')\n",
    "summary_compute = perform_varying_compute_analysis(\n",
    "    df,\n",
    "    [2.56e19, 5.76e23],\n",
    "    config_compute,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the varying-compute summary computed in the previous cell.\n",
    "accuracy_vs_compute_figure(\n",
    "    summary_compute,\n",
    "    save=save_figs,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Power laws for loss "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bootstrap_num is set to 0 for this figure.\n",
    "opt_loss_figure(\n",
    "    summary_df,\n",
    "    save=save_figs,\n",
    "    bootstrap_num=0,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this should take ~8 minutes\n",
    "# Extended version of the loss figure, run with bootstrap_num=200.\n",
    "opt_loss_extended_figure(\n",
    "    summary_df,\n",
    "    save=save_figs,\n",
    "    bootstrap_num=200,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters sweep results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the interpolated hyperparameter frame and the fit from the processed sweep.\n",
    "df_sweep_opt_eta_and_bs, fit = get_interpolated_hparams_dfs(df_sweep)\n",
    "\n",
    "# Plot the sweep together with the interpolated values and the fit.\n",
    "hparams_fit_figure(\n",
    "    df_sweep,\n",
    "    df_sweep_opt_eta_and_bs,\n",
    "    fit,\n",
    "    save=save_figs,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the hyperparameter fit using only the sweep rows with beta2 == 0.95.\n",
    "df_sweep_beta2_095 = process_sweep_df(\n",
    "    df.query(\"hparams=='sweep'\").query('beta2==0.95').copy()\n",
    ")\n",
    "\n",
    "df_sweep_beta2_095_opt_eta_and_bs, fit_beta095 = get_interpolated_hparams_dfs(df_sweep_beta2_095)\n",
    "\n",
    "# Written to its own path so it does not collide with the default sweep figure.\n",
    "hparams_fit_figure(\n",
    "    df_sweep_beta2_095,\n",
    "    df_sweep_beta2_095_opt_eta_and_bs,\n",
    "    fit_beta095,\n",
    "    save=save_figs,\n",
    "    save_path='paper/figures/hparams_fit_0.95.pdf',\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The full-sweep figure takes the pivoted form of the processed sweep frame.\n",
    "full_sweep_figure(\n",
    "    create_pivot_df(df_sweep),\n",
    "    save=save_figs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the results table from the main and OWT2 summaries stacked together.\n",
    "results_table_df = results_table(\n",
    "    pd.concat([summary_df, summary_df_owt2]),\n",
    "    flop_vals=FLOP_VALS,\n",
    "    validation='all',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This cell extends results_table_df in place. The guard makes it safe to\n",
    "# re-run: without it, every execution would append the same rows again and\n",
    "# shift the index a second time.\n",
    "if \"Hoffmann Law\" not in results_table_df.iloc[:, 0].values:\n",
    "    # Append the first row of the Kaplan/tuned-hyperparameters table at the end.\n",
    "    results_table_df.loc[len(results_table_df)] = results_table(summary_df_kaplan_tuned_hparams, flop_vals=FLOP_VALS, validation='all').iloc[0]\n",
    "    # The two hand-written rows get temporary negative labels so that, after\n",
    "    # shifting the index by 2 and sorting, they end up at the top of the table.\n",
    "    results_table_df.loc[-1] = [\"Kaplan Law\", \"WebText2\", \"0.88\", \"\", \"\"]\n",
    "    results_table_df.loc[-2] = [\"Hoffmann Law\", \"MassiveText\", \"0.5\", \"\", \"\"]\n",
    "    results_table_df.index = results_table_df.index + 2\n",
    "    results_table_df = results_table_df.sort_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the assembled results table as the cell output.\n",
    "results_table_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tuned-hyperparameter table for the rw / tuned / short-warmup rows, ordered by 'Batch size'.\n",
    "(\n",
    "    tuned_hparams(\n",
    "        df.query('dataset==\"rw\" and hparams==\"tuned\" and warmup==\"short\"')\n",
    "    )\n",
    "    .sort_values('Batch size')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture table built from the rw / tuned / short-warmup rows.\n",
    "archs_table_df = archs_table(\n",
    "    df.query('dataset==\"rw\" and hparams==\"tuned\" and warmup==\"short\"')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the architecture table as the cell output.\n",
    "archs_table_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
