{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "palette = sns.color_palette(\"colorblind\")\n",
    "\n",
    "from utils import compute_compute, color_rule, model2compute\n",
    "from utils import load_benchmark_results, load_steps\n",
    "from utils import process_pre_post_adjustment\n",
    "\n",
    "from plot_utils import plot_bench, plot_regressor, emergence_plots\n",
    "from regress_utils import get_features, regress_seg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dirs = {\n",
    "    'base': 'evaluations/base/',\n",
    "    'mmluaux': 'evaluations/e3/mmluaux/',\n",
    "    'gsm8kaux': 'evaluations/e3/gsm8kaux/',\n",
    "}\n",
    "\n",
    "all_results = {name: load_benchmark_results(dir_) for name, dir_ in dirs.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "process_kwargs = {'verbose': False, 'use_max': True}\n",
    "mmlu_pre, mmlu_post = process_pre_post_adjustment(\n",
    "    all_results['base']['mmlu-acc'], \n",
    "    all_results['mmluaux']['mmlu-acc'],\n",
    "    **process_kwargs,\n",
    ")\n",
    "\n",
    "gsm8k_pre, gsm8k_post = process_pre_post_adjustment(\n",
    "    all_results['base']['gsm8k'],\n",
    "    all_results['gsm8kaux']['gsm8k'],\n",
    "    **process_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Main plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def first_plot(rule=None):\n",
    "    fig, axs = plt.subplots(2, 2, figsize=(8.6, 6), dpi=150, sharex=False, sharey=False)\n",
    "    plot_regressor(*get_features(mmlu_pre, rule=rule), 0.25, ax=axs[0, 0])\n",
    "    plot_regressor(*get_features(gsm8k_pre, rule=rule), 0., ax=axs[0, 1])\n",
    "    plot_regressor(*get_features(mmlu_post, rule=rule), 0.25, ax=axs[1, 0])\n",
    "    plot_regressor(*get_features(gsm8k_post, rule=rule), 0., ax=axs[1, 1])\n",
    "\n",
    "    mmlu_ylim = (0.23, 0.75)\n",
    "    mmlu_yticks = [0.3, 0.4, 0.5, 0.6, 0.7]\n",
    "    gsm8k_ylim = (-0.03, 0.83)\n",
    "    gsm8k_yticks = [0.0, 0.2, 0.4, 0.6, 0.8]\n",
    "    for i in range(2):\n",
    "        axs[i, 0].set_ylim(*mmlu_ylim)\n",
    "        axs[i, 0].set_yticks(mmlu_yticks)\n",
    "        axs[i, 1].set_ylim(*gsm8k_ylim)\n",
    "        axs[i, 1].set_yticks(gsm8k_yticks)\n",
    "        axs[i, 0].set_title('MMLU', fontsize=13)\n",
    "        axs[i, 1].set_title('GSM8K', fontsize=13)\n",
    "\n",
    "    for i in range(2):\n",
    "        for j in range(2):\n",
    "            axs[i, j].set_ylabel('Accuracy', fontsize=13)\n",
    "            axs[i, j].set_xlabel('Pretraining compute (FLOPs)', fontsize=13)\n",
    "            axs[i, j].tick_params(axis='both', which='both', length=0, labelsize=12)\n",
    "\n",
    "    return fig, axs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = first_plot()\n",
    "\n",
    "upper_text = 'Base models trained after November 2023 outperform those trained before November 2023'\n",
    "lower_text = 'After fine-tuning all models on the test task, differences in model performance vanish'\n",
    "fig.text(0.5, 1.005, upper_text, ha='center', va='center', fontsize=14)\n",
    "fig.text(0.5, 0.495, lower_text, ha='center', va='center', fontsize=14)\n",
    "\n",
    "legend1 = 'Models trained'\n",
    "legend2 = 'Before November 2023'\n",
    "legend3 = 'After November 2023      '\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='', color='w', label=legend1, markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend2, markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend3, markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.06),\n",
    "        loc='lower center', ncol=3, fontsize=13, frameon=True, columnspacing=0.5, handletextpad=-.2,)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.3, hspace=0.55)\n",
    "plt.savefig('plots/first.pdf', bbox_inches='tight', dpi=150)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same plot, but with the alternative EN / CN split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def first_plot_compact(rule=None):\n",
    "    fig, axs = plt.subplots(2, 2, figsize=(8.6, 4.5), dpi=150, sharex=False, sharey=False)\n",
    "    plot_regressor(*get_features(mmlu_pre, rule=rule), 0.25, ax=axs[0, 0])\n",
    "    plot_regressor(*get_features(gsm8k_pre, rule=rule), 0., ax=axs[0, 1])\n",
    "    plot_regressor(*get_features(mmlu_post, rule=rule), 0.25, ax=axs[1, 0])\n",
    "    plot_regressor(*get_features(gsm8k_post, rule=rule), 0., ax=axs[1, 1])\n",
    "\n",
    "    mmlu_ylim = (0.23, 0.75)\n",
    "    mmlu_yticks = [0.3, 0.4, 0.5, 0.6, 0.7]\n",
    "    gsm8k_ylim = (-0.03, 0.83)\n",
    "    gsm8k_yticks = [0.0, 0.2, 0.4, 0.6, 0.8]\n",
    "    for i in range(2):\n",
    "        axs[i, 0].set_ylim(*mmlu_ylim)\n",
    "        axs[i, 0].set_yticks(mmlu_yticks)\n",
    "        axs[i, 1].set_ylim(*gsm8k_ylim)\n",
    "        axs[i, 1].set_yticks(gsm8k_yticks)\n",
    "\n",
    "    axs[0, 0].set_title('MMLU', fontsize=13)\n",
    "    axs[0, 1].set_title('GSM8K', fontsize=13)\n",
    "    axs[1, 0].set_title('MMLU (adjusted)', fontsize=13)\n",
    "    axs[1, 1].set_title('GSM8K (adjusted)', fontsize=13)\n",
    "\n",
    "    for i in range(2):\n",
    "        for j in range(2):\n",
    "            axs[i, j].set_ylabel('Accuracy', fontsize=13)\n",
    "            axs[-1, j].set_xlabel('Pretraining compute (FLOPs)', fontsize=13)\n",
    "            axs[i, j].tick_params(axis='both', which='both', length=0, labelsize=12)\n",
    "    \n",
    "    axs[0, 0].set_xticklabels([])\n",
    "    axs[0, 1].set_xticklabels([])\n",
    "\n",
    "    return fig, axs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "robustness_rule = lambda c: color_rule(c, use_date=False)\n",
    "\n",
    "fig, axs = first_plot_compact(robustness_rule)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='', color='w', label='Models trained', markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='Primarily on EN', markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='On both EN and CN        ', markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.08),\n",
    "        loc='lower center', ncol=3, fontsize=13, frameon=True, columnspacing=0.5, handletextpad=-.2,)\n",
    "\n",
    "fig.text(0.5, 1.005, 'Model split: trained primarely on EN data', ha='center', va='center', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.4, hspace=0.3)\n",
    "plt.savefig('plots/first-cn.pdf', bbox_inches='tight', dpi=150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "robustness_rule = lambda c: color_rule(c, threshold='2309')\n",
    "\n",
    "fig, axs = first_plot_compact(robustness_rule)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='', color='w', label='Models trained', markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='Before September 2023', markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='After September 2023        ', markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.08),\n",
    "        loc='lower center', ncol=3, fontsize=13, frameon=True, columnspacing=0.5, handletextpad=-.2,)\n",
    "\n",
    "fig.text(0.5, 1.005, 'Model split: trained before September 2023', ha='center', va='center', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.4, hspace=0.3)\n",
    "plt.savefig('plots/rob1.pdf', bbox_inches='tight', dpi=150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "robustness_rule = lambda c: color_rule(c, threshold='2401')\n",
    "\n",
    "fig, axs = first_plot_compact(robustness_rule)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='', color='w', label='Models trained', markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='Before January 2024', markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='After January 2024        ', markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.08),\n",
    "        loc='lower center', ncol=3, fontsize=13, frameon=True, columnspacing=0.5, handletextpad=-.2,)\n",
    "\n",
    "fig.text(0.5, 1.005, 'Model split: trained before January 2024', ha='center', va='center', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.4, hspace=0.3)\n",
    "plt.savefig('plots/rob2.pdf', bbox_inches='tight', dpi=150)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the gain in accuracy after adjustment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diff_mmlu = {m: mmlu_post[m] - mmlu_pre[m] for m in mmlu_pre}\n",
    "diff_gsm8k = {m: gsm8k_post[m] - gsm8k_pre[m] for m in gsm8k_pre}\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(7.5, 1.8), dpi=150, sharex=True)\n",
    "\n",
    "xticks = [1e20, 1e21, 1e22, 1e23, 1e24]\n",
    "\n",
    "msize = 50\n",
    "plot_bench(ax[0], diff_mmlu, title='MMLU', xticks=xticks, msize=msize, title_fontsize=12, color_code='date')\n",
    "plot_bench(ax[1], diff_gsm8k, title='GSM8K', xticks=xticks, msize=msize, title_fontsize=12, color_code='date')\n",
    "\n",
    "ax[0].set_ylabel('Gain in accuracy\\nafter adjustment', fontsize=12)\n",
    "for i in range(2):\n",
    "    ax[i].set_xlabel('Pretraining compute (FLOPs)', fontsize=12)\n",
    "    ax[i].tick_params(axis='both', which='both', length=0, labelsize=12)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.2)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='', color='w', label=legend1, markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend2, markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend3, markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.4),\n",
    "        loc='lower center', ncol=3, fontsize=12, frameon=True, columnspacing=0.5, handletextpad=.2,)\n",
    "\n",
    "# save as pdf on plots/\n",
    "plt.savefig('plots/gain_benchmarks.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Section 3: recreating the differences observed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the old models\n",
    "models = [m for m in model2compute if color_rule(m) == palette[0]]\n",
    "mmlu = {m: a for m, a in all_results['base']['mmlu-acc'].items() if m in models}\n",
    "gsm8k = {m: a for m, a in all_results['base']['gsm8k'].items() if m in models}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_mmlu_gsm8k(base_dir):\n",
    "    mmlu = load_benchmark_results(base_dir + '/mmluaux/')['mmlu-acc']\n",
    "    gsm8k = load_benchmark_results(base_dir + '/gsm8kaux/')['gsm8k']\n",
    "    return mmlu, gsm8k\n",
    "\n",
    "mmlu_ft, gsm8k_ft = load_mmlu_gsm8k('evaluations/e1')\n",
    "mmlu_adj, gsm8k_adj = load_mmlu_gsm8k('evaluations/e2')\n",
    "# mmlu_adj, gsm8k_adj = load_mmlu_gsm8k('evaluations/e3')\n",
    "mmlu_ft_adj, gsm8k_ft_adj = load_mmlu_gsm8k('evaluations/e1+2')\n",
    "# mmlu_ft_adj, gsm8k_ft_adj = load_mmlu_gsm8k('evaluations/e1+3')\n",
    "\n",
    "mmlu, mmlu_ft = process_pre_post_adjustment(mmlu, mmlu_ft, **process_kwargs)\n",
    "gsm8k, gsm8k_ft = process_pre_post_adjustment(gsm8k, gsm8k_ft, **process_kwargs)\n",
    "\n",
    "mmlu_ft, mmlu_ft_adj = process_pre_post_adjustment(mmlu_ft, mmlu_ft_adj, **process_kwargs)\n",
    "gsm8k_ft, gsm8k_ft_adj = process_pre_post_adjustment(gsm8k_ft, gsm8k_ft_adj, **process_kwargs)\n",
    "\n",
    "mmlu, mmlu_adj = process_pre_post_adjustment(mmlu, mmlu_adj, **process_kwargs)\n",
    "gsm8k, gsm8k_adj = process_pre_post_adjustment(gsm8k, gsm8k_adj, **process_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = set(model2compute.keys())\n",
    "models &= set(mmlu.keys()) & set(gsm8k.keys())\n",
    "models = sorted([m for m in models if color_rule(m) == palette[0]])\n",
    "\n",
    "too_small = [m for m in models if compute_compute(m) <= compute_compute('qwen-1.5-0.5b')]\n",
    "others = [m for m in models if m not in too_small]\n",
    "models1 = others + too_small\n",
    "models2 = others\n",
    "\n",
    "mmlu_ = {k: v for k, v in mmlu.items() if k in models2}\n",
    "mmlu = {k: v for k, v in mmlu.items() if k in models1}\n",
    "mmlu_ft = {k: v for k, v in mmlu_ft.items() if k in models2}\n",
    "mmlu_adj = {k: v for k, v in mmlu_adj.items() if k in models1}\n",
    "mmlu_ft_adj = {k: v for k, v in mmlu_ft_adj.items() if k in models2}\n",
    "\n",
    "gsm8k_ = {k: v for k, v in gsm8k.items() if k in models2}\n",
    "gsm8k = {k: v for k, v in gsm8k.items() if k in models1}\n",
    "gsm8k_ft = {k: v for k, v in gsm8k_ft.items() if k in models2}\n",
    "gsm8k_adj = {k: v for k, v in gsm8k_adj.items() if k in models1}\n",
    "gsm8k_ft_adj = {k: v for k, v in gsm8k_ft_adj.items() if k in models2}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ax(ax, blue_up, orange_up, blue_down, orange_down, r):\n",
    "    def combine_plot(ax, r1, r2, split=True):\n",
    "        c1, _, a1 = get_features(r1)\n",
    "        c2, _, a2 = get_features(r2)\n",
    "        if split:\n",
    "            f = [0]*len(c1) + [1]*len(c2)\n",
    "        else:\n",
    "            f = [0] * (len(c1) + len(c2))\n",
    "        plot_regressor(c1 + c2, f, a1 + a2, r, ax=ax, plot_lines=split, main=False)\n",
    "\n",
    "    combine_plot(ax[0], blue_up, orange_up)\n",
    "    combine_plot(ax[1], blue_down, orange_down)\n",
    "\n",
    "    # set all y to the same\n",
    "    max_ = max(max(blue_up.values()), max(orange_up.values()), max(blue_down.values()), max(orange_down.values()))\n",
    "    min_ = min(min(blue_up.values()), min(orange_up.values()), min(blue_down.values()), min(orange_down.values()))\n",
    "\n",
    "    ax[0].set_ylim(min_ - 0.03, max_+0.03)\n",
    "    for i in range(2):\n",
    "        ax[i].grid(alpha=0.5)\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(8.5, 3.7), sharex=True, sharey=False)\n",
    "\n",
    "kwargs = {'msize': 40, 'alpha': 0.7}\n",
    "thetas = [[0.001, 0.103, -0.011], [-0.011, 0.181, -0.011]]\n",
    "for j, ax in enumerate(axs):\n",
    "    if j == 0:\n",
    "        plot_ax([axs[0, 0], axs[1, 0]], mmlu, mmlu_ft, mmlu_adj, mmlu_ft_adj, 0.25)\n",
    "    else:\n",
    "        plot_ax([axs[0, 1], axs[1, 1]], gsm8k, gsm8k_ft, gsm8k_adj, gsm8k_ft_adj, 0)\n",
    "\n",
    "    # y_suffix = '' if j == 0 else '(adjusted)'\n",
    "    ax[0].set_ylabel('Accuracy', fontsize=12)\n",
    "    ax[1].set_ylabel('Accuracy', fontsize=12)\n",
    "    yticks = [0.3, 0.4, 0.5, 0.6] if j == 0 else [0., 0.2, 0.4, 0.6]\n",
    "    for i in range(2):\n",
    "        axs[i, j].set_yticks(yticks)\n",
    "\n",
    "    # no ticks on y axis for the second row\n",
    "    # for i in range(1, 2):\n",
    "    #     ax[i].set_yticklabels([])\n",
    "\n",
    "for i in range(2):\n",
    "    for j in range(2):\n",
    "        ax = axs[i, j]\n",
    "        ax.tick_params(axis='both', which='both', length=0, labelsize=10)\n",
    "\n",
    "# axs[0, 0].set_title('No model fine-tuned\\n', fontsize=12)\n",
    "# axs[0, 0].set_title('Without adjustment', fontsize=12)\n",
    "# axs[0, 1].set_title('Some models fine-tuned\\n(adjusted)', fontsize=12)\n",
    "\n",
    "axs[0, 0].set_title('MMLU', fontsize=12)\n",
    "axs[1, 0].set_title('MMLU (adjusted)', fontsize=12)\n",
    "axs[0, 1].set_title('GSM8K', fontsize=12)\n",
    "axs[1, 1].set_title('GSM8K (adjusted)', fontsize=12)\n",
    "\n",
    "\n",
    "for i in range(2):\n",
    "    axs[1, i].set_xlabel('Pretraining compute (FLOPs)', fontsize=12)\n",
    "\n",
    "xticks = [1e20, 1e21, 1e22, 1e23]\n",
    "axs[0, 0].set_xticks(xticks)\n",
    "\n",
    "# upper_text = 'No adjustment of the benchmark scores'\n",
    "# lower_text = 'Benchmark scores adjusted by fine-tuning all models on the test task'\n",
    "# fig.text(0.5, 0.955, upper_text, ha='center', va='center', fontsize=11)\n",
    "# fig.text(0.5, 0.495, lower_text, ha='center', va='center', fontsize=11)\n",
    "\n",
    "# legend1 = 'Models released before November 2023'\n",
    "legend2 = 'Without fine-tuning'\n",
    "legend3 = 'After fine-tuning on the test task'\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend2, markerfacecolor=palette[0], markersize=8),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend3, markerfacecolor=palette[1], markersize=8),\n",
    "]\n",
    "\n",
    "# fig.text(0.5, -0.09, 'Models released before November 2023', ha='center', va='center', fontsize=11)\n",
    "# fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.21),\n",
    "#         loc='lower center', ncol=3, fontsize=11, frameon=False, columnspacing=0.5, handletextpad=-.2,)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.35, hspace=0.4)\n",
    "\n",
    "plt.savefig('plots/ft.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmarks = {\n",
    "    'ARC\\nAccuracy': 'arc_challenge',\n",
    "    'HellaSwag\\nAccuracy': 'hellaswag',\n",
    "}\n",
    "\n",
    "fig, axs = plt.subplots(2, 3, figsize=(9.5, 3), sharex=True, sharey=True)\n",
    "\n",
    "kwargs = {'msize': 40, 'alpha': 0.7}\n",
    "thetas = [[0.001, 0.120, -0.014], [-0.012, 0.114, 0.009]]\n",
    "for j, (ax, (name, benchmark)) in enumerate(zip(axs, benchmarks.items())):\n",
    "    not_mc = all_results['base'][benchmark + '-acc']\n",
    "    mc = all_results['base'][benchmark + '_mc-acc']\n",
    "    mc_adj = all_results['mmluaux'][benchmark + '_mc-acc']\n",
    "    mc, mc_adj = process_pre_post_adjustment(mc, mc_adj, **process_kwargs)\n",
    "\n",
    "    plot_regressor(*get_features(not_mc), 0.25, ax=ax[0], main=False)\n",
    "    plot_regressor(*get_features(mc), 0.25, ax=ax[1], main=False)\n",
    "    plot_regressor(*get_features(mc_adj), 0.25, ax=ax[2], main=False)\n",
    "\n",
    "    ax[0].set_ylabel(name, fontsize=12)\n",
    "\n",
    "    # add legend\n",
    "    for i, theta in enumerate(thetas[j]):\n",
    "        to_add = format(theta, \".3f\")\n",
    "        if i == 1:\n",
    "            to_add = '\\\\mathbf{' + to_add + '}'\n",
    "        label = \"\\\\hat{\\\\theta} = \" + to_add\n",
    "        label = '$' + label + '$'\n",
    "        legend_elements = [\n",
    "            plt.Line2D([0], [0], marker='o', color='w', label=label, markerfacecolor='black', markersize=0),\n",
    "        ]\n",
    "        ax[i].legend(handles=legend_elements, loc='upper left', fontsize=12, handletextpad=-2)\n",
    "\n",
    "for i in range(2):\n",
    "    for j in range(3):\n",
    "        ax = axs[i, j]\n",
    "        ax.tick_params(axis='both', which='both', length=0, labelsize=10)\n",
    "\n",
    "axs[0, 0].set_title('Cloze evaluation\\n', fontsize=12)\n",
    "axs[0, 1].set_title('Multiple choice\\n', fontsize=12)\n",
    "axs[0, 2].set_title('Multiple choice\\n(adjusted)', fontsize=12)\n",
    "\n",
    "for i in range(3):\n",
    "    axs[-1, i].set_xlabel('Pretraining compute (FLOPs)', fontsize=12)\n",
    "\n",
    "xticks = [1e20, 1e21, 1e22, 1e23, 1e24]\n",
    "axs[0, 0].set_xticks(xticks)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.12, hspace=0.05)\n",
    "plt.savefig('plots/reformulate-adjusted.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmarks = {\n",
    "    'ARC': 'arc_challenge',\n",
    "    'HellaSwag\\nAccuracy': 'hellaswag',\n",
    "}\n",
    "\n",
    "fig, axs = plt.subplots(1, 3, figsize=(9.5, 1.5), sharex=True, sharey=False)\n",
    "\n",
    "kwargs = {'msize': 40, 'alpha': 0.7}\n",
    "\n",
    "# make brier negative\n",
    "all_results['base']['mmlu-brier'] = {m: -abs(v) for m, v in all_results['base']['mmlu-brier'].items()}\n",
    "\n",
    "plot_regressor(*get_features(mmlu_pre), 0.25, ax=axs[0], main=False)\n",
    "plot_regressor(*get_features(all_results['base']['mmlu_cloze-acc']), 0.25, ax=axs[1], main=False)\n",
    "plot_regressor(*get_features(all_results['base']['mmlu-brier']), -0.75, ax=axs[2], plot_lines=False, main=False)\n",
    "\n",
    "axs[0].set_ylabel('Accuracy', fontsize=12)\n",
    "axs[2].set_ylabel('-Brier score', fontsize=12)\n",
    "axs[1].set_ylabel('Accuracy', fontsize=12)\n",
    "\n",
    "# for i in range(2):\n",
    "for j in range(3):\n",
    "    ax = axs[j]\n",
    "    ax.tick_params(axis='both', which='both', length=0, labelsize=10)\n",
    "\n",
    "axs[0].set_title('MMLU - Multiple choice', fontsize=12)\n",
    "axs[2].set_title('MMLU - Multiple choice', fontsize=12)\n",
    "axs[1].set_title('MMLU - Cloze', fontsize=12)\n",
    "\n",
    "for i in range(3):\n",
    "    axs[i].set_xlabel('Pretraining compute (FLOPs)', fontsize=12)\n",
    "\n",
    "xticks = [1e20, 1e21, 1e22, 1e23, 1e24]\n",
    "axs[0].set_xticks(xticks)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.4, hspace=0.05)\n",
    "plt.savefig('plots/mmlu-emergence.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmarks = {\n",
    "    'ARC\\nAccuracy': 'arc_challenge',\n",
    "    'HellaSwag\\nAccuracy': 'hellaswag',\n",
    "}\n",
    "\n",
    "fig, axs = plt.subplots(2, 3, figsize=(9.5, 2.5), sharex=True, sharey=False)\n",
    "\n",
    "results_diff = {}\n",
    "kwargs = {'msize': 40, 'alpha': 0.7}\n",
    "thetas = [[0.001, 0.120, -0.014], [-0.012, 0.114, 0.009]]\n",
    "for j, (ax, (name, benchmark)) in enumerate(zip(axs, benchmarks.items())):\n",
    "    plot_regressor(*get_features(all_results['base'][benchmark + '-acc']), 0.25, ax=ax[0], plot_lines=False, main=False)\n",
    "    plot_regressor(*get_features(all_results['base'][benchmark + '_mc-acc']), 0.25, ax=ax[1], plot_lines=False, main=False)\n",
    "\n",
    "    # change brier to negative\n",
    "    neg_brier = {m: -v for m, v in all_results['base'][benchmark + '_mc-brier'].items()}\n",
    "    plot_regressor(*get_features(neg_brier), 0.25, ax=ax[2], plot_lines=False, main=False)\n",
    "\n",
    "    ax[0].set_ylabel(name, fontsize=12)\n",
    "    ax[1].set_ylabel('Accuracy', fontsize=12)\n",
    "    ax[2].set_ylabel('-Brier score', fontsize=12)\n",
    "\n",
    "\n",
    "for i in range(2):\n",
    "    for j in range(3):\n",
    "        ax = axs[i, j]\n",
    "        ax.tick_params(axis='both', which='both', length=0, labelsize=10)\n",
    "\n",
    "axs[0, 0].set_title('Cloze', fontsize=12)\n",
    "axs[0, 1].set_title('Multiple choice', fontsize=12)\n",
    "axs[0, 2].set_title('Multiple choice', fontsize=12)\n",
    "\n",
    "for i in range(3):\n",
    "    axs[-1, i].set_xlabel('Pretraining compute', fontsize=12)\n",
    "\n",
    "xticks = [1e20, 1e21, 1e22, 1e23, 1e24]\n",
    "axs[0, 0].set_xticks(xticks)\n",
    "ax.set_xscale('log')\n",
    "\n",
    "plt.subplots_adjust(wspace=0.4, hspace=0.05)\n",
    "plt.savefig('plots/brier-arc.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Section 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the gain in pareto area"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pareto_frontier(results):\n",
    "    \"\"\" constructs the pareto frontier of performance against compute \"\"\"\n",
    "    # sort models by compute\n",
    "    compute = {m: compute_compute(m) for m in results}\n",
    "    models = sorted(results, key=lambda m: compute[m])\n",
    "\n",
    "    pareto_result = [results[models[0]]]\n",
    "    pareto_compute = [compute[models[0]]]\n",
    "    for m in models[1:]:\n",
    "        if results[m] > pareto_result[-1]:\n",
    "            pareto_result.append(results[m])\n",
    "            pareto_compute.append(compute[m])\n",
    "    \n",
    "    pareto_compute.append(compute[models[-1]])\n",
    "    pareto_result.append(pareto_result[-1])\n",
    "    return pareto_compute, pareto_result\n",
    "\n",
    "def get_pareto_plot_points(pareto):\n",
    "    c, a = pareto\n",
    "    for i in range(len(c) - 1):\n",
    "        yield c[i], a[i]\n",
    "        yield c[i+1], a[i]\n",
    "    yield c[-1], a[-1]\n",
    "\n",
    "def get_new_pareto(pareto, new_c):\n",
    "    pareto_c, pareto_a = pareto\n",
    "    i = 0\n",
    "    for c in new_c:\n",
    "        while i + 1 < len(pareto_c) and pareto_c[i+1] <= c:\n",
    "            i += 1\n",
    "        yield c, pareto_a[i]\n",
    "\n",
    "def integrate(pareto1, pareto2):\n",
    "    c1, a1 = pareto1\n",
    "    c2, a2 = pareto2\n",
    "\n",
    "    assert c1 == c2\n",
    "    c = np.log10(c1)\n",
    "    da = (np.array(a2) - np.array(a1))[:-1]\n",
    "    total = np.sum(da * np.diff(c))\n",
    "    return total\n",
    "\n",
    "\n",
    "def plot_pareto(results, ax=None, ylabel=None, xlabel=None, title=None):\n",
    "    if ax is None:\n",
    "        _, ax = plt.subplots(1, 1, figsize=(4.5, 2), dpi=150)\n",
    "\n",
    "    results_old = {m: v for m, v in results.items() if color_rule(m) == palette[0]}\n",
    "\n",
    "    pareto_pre = pareto_frontier(results_old)\n",
    "    pareto_post = pareto_frontier(results)\n",
    "\n",
    "    union_compute = sorted(list(set(pareto_pre[0]) | set(pareto_post[0])))\n",
    "    pareto_pre = list(zip(*get_new_pareto(pareto_pre, union_compute)))\n",
    "    pareto_post = list(zip(*get_new_pareto(pareto_post, union_compute)))\n",
    "\n",
    "    plot_pareto_pre = list(zip(*get_pareto_plot_points(pareto_pre)))\n",
    "    plot_pareto_post = list(zip(*get_pareto_plot_points(pareto_post)))\n",
    "\n",
    "    ax.plot(*plot_pareto_pre, color=palette[0], alpha=1., linewidth=3)\n",
    "    ax.plot(*plot_pareto_post, color=palette[1], alpha=1., linewidth=3)\n",
    "\n",
    "    # shade the area between the two curves\n",
    "    area = integrate(pareto_pre, pareto_post)\n",
    "    ax.fill_between(plot_pareto_pre[0], plot_pareto_pre[1], plot_pareto_post[1], color='g', alpha=0.15, label=f'{area:.2f}')\n",
    "\n",
    "    ax.set_xscale('log')\n",
    "\n",
    "    if ylabel is not None:\n",
    "        ax.set_ylabel(ylabel, fontsize=12)\n",
    "    if xlabel is not None:\n",
    "        ax.set_xlabel(xlabel, fontsize=12)\n",
    "    if title is not None:\n",
    "        ax.set_title(title, fontsize=12)\n",
    "\n",
    "    # legend below the plot\n",
    "    ax.legend(fontsize=11, frameon=False, loc='upper left', handletextpad=0.2, bbox_to_anchor=(-0.05, 1.05))\n",
    "\n",
    "    ax.xaxis.set_tick_params(which='both', length=0)\n",
    "    ax.yaxis.set_tick_params(which='both', length=0)\n",
    "\n",
    "    ax.grid(alpha=0.5)\n",
    "\n",
    "    return area\n",
    "\n",
    "fig = plt.figure(figsize=(8., 1.4))\n",
    "\n",
    "gs = fig.add_gridspec(1, 4, wspace=0.5)\n",
    "axs = [fig.add_subplot(gs[0, i]) for i in range(4)]\n",
    "pos = [ax.get_position() for ax in axs]\n",
    "\n",
    "width = 0.2\n",
    "dplot = 0.08\n",
    "dw = 0.10\n",
    "axs[0].set_position([pos[0].x0, pos[0].y0, width, pos[0].height])\n",
    "axs[1].set_position([pos[0].x0 + pos[0].width + dplot, pos[1].y0, width, pos[1].height])\n",
    "axs[2].set_position([pos[2].x0 + dw, pos[2].y0, width, pos[3].height])\n",
    "axs[3].set_position([pos[2].x0 + dw + pos[2].width + dplot, pos[3].y0, width, pos[3].height])\n",
    "\n",
    "plot_pareto(mmlu_pre, ax=axs[0])\n",
    "plot_pareto(mmlu_post, ax=axs[1])\n",
    "\n",
    "plot_pareto(gsm8k_pre, ax=axs[2])\n",
    "plot_pareto(gsm8k_post, ax=axs[3])\n",
    "\n",
    "axs[0].set_yticks([0.3, 0.5, 0.7])\n",
    "\n",
    "for ax in axs:\n",
    "    ax.set_xlim(3e21, None)\n",
    "\n",
    "axs[1].set_yticklabels([])\n",
    "axs[3].set_yticklabels([])\n",
    "\n",
    "mmlu_ylim = (0.24, 0.75)\n",
    "gsm8k_ylim = (-0.02, 0.8)\n",
    "axs[0].set_ylim(*mmlu_ylim)\n",
    "axs[1].set_ylim(*mmlu_ylim)\n",
    "axs[2].set_ylim(*gsm8k_ylim)\n",
    "axs[3].set_ylim(*gsm8k_ylim)\n",
    "\n",
    "axs[0].set_ylabel('Accuracy', fontsize=12)\n",
    "axs[2].set_ylabel('Accuracy', fontsize=12)\n",
    "# axs[0].set_xlabel('Compute', fontsize=12)\n",
    "# axs[1].set_xlabel('Compute', fontsize=12)\n",
    "# axs[2].set_xlabel('Compute', fontsize=12)\n",
    "# axs[3].set_xlabel('Compute', fontsize=12)\n",
    "xshift = 1.08\n",
    "axs[0].set_xlabel('Pretraining compute (FLOPs)', fontsize=12, x=xshift)\n",
    "axs[2].set_xlabel('Pretraining compute (FLOPs)', fontsize=12, x=xshift)\n",
    "axs[0].set_title('Unadjusted', fontsize=12)\n",
    "axs[1].set_title('Adjusted', fontsize=12)\n",
    "axs[2].set_title('Unadjusted', fontsize=12)\n",
    "axs[3].set_title('Adjusted', fontsize=12)\n",
    "\n",
    "fig.text(0.35, 1.09, 'MMLU', ha='center', va='center', fontsize=13)\n",
    "fig.text(0.88, 1.09, 'GSM8K', ha='center', va='center', fontsize=13)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], color='w', label='Pareto front of ', markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], color=palette[0], label='Models trained before November 2023', linewidth=4),\n",
    "    plt.Line2D([0], [0], color=palette[1], label='All models', linewidth=4),\n",
    "]\n",
    "\n",
    "legend_position = (0.57,  -0.54)\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=legend_position,\n",
    "        loc='lower center', ncol=3, fontsize=12, frameon=True, columnspacing=1.2, handletextpad=.9, handlelength=1.2)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('plots/pareto_mmlu.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare Pythia, Llama 2 and Qwen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8., 1.4))\n",
    "\n",
    "gs = fig.add_gridspec(1, 4, wspace=0.5)\n",
    "axs = [fig.add_subplot(gs[0, i]) for i in range(4)]\n",
    "pos = [ax.get_position() for ax in axs]\n",
    "\n",
    "width = 0.2\n",
    "dplot = 0.08\n",
    "dw = 0.11\n",
    "axs[0].set_position([pos[0].x0, pos[0].y0, width, pos[0].height])\n",
    "axs[1].set_position([pos[0].x0 + pos[0].width + dplot, pos[1].y0, width, pos[1].height])\n",
    "axs[2].set_position([pos[2].x0 + dw, pos[2].y0, width, pos[3].height])\n",
    "axs[3].set_position([pos[2].x0 + dw + pos[2].width + dplot, pos[3].y0, width, pos[3].height])\n",
    "\n",
    "def plot_llama_qwen(result, ax=None):\n",
    "    \"\"\"Scatter accuracy against pretraining compute for Qwen, Llama 2 and Pythia.\n",
    "\n",
    "    Args:\n",
    "        result: mapping from model name to accuracy.\n",
    "        ax: axes to draw on; a fresh 6x4 inch figure is created when omitted.\n",
    "\n",
    "    Families are picked by substring match on the model name and the x value\n",
    "    of each point is ``compute_compute(name)``. Nothing is returned.\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        _, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=150)\n",
    "\n",
    "    # (legend label, colour, name predicate), drawn in this order\n",
    "    families = [\n",
    "        ('Qwen', palette[1], lambda m: 'qwen' in m),\n",
    "        ('Llama 2', palette[0], lambda m: 'llama-2' in m and 'open' not in m and 'llama-3' not in m),\n",
    "        ('Pythia', palette[2], lambda m: 'pythia' in m and 'b' in m),\n",
    "    ]\n",
    "\n",
    "    for label, colour, belongs in families:\n",
    "        names = [m for m in result if belongs(m)]\n",
    "        # `color=` rather than `c=`: a bare RGB tuple passed as `c` is read as\n",
    "        # three scalar values to colour-map whenever a family has exactly 3 models.\n",
    "        ax.scatter(\n",
    "            [compute_compute(m) for m in names],\n",
    "            [result[m] for m in names],\n",
    "            color=colour, label=label, alpha=0.85, s=50, zorder=1,\n",
    "        )\n",
    "\n",
    "    ax.set_xscale('log')\n",
    "    ax.grid(zorder=-10)\n",
    "    ax.set_axisbelow(True)\n",
    "\n",
    "    # hide tick marks (major and minor) on both axes, keep the labels\n",
    "    ax.xaxis.set_tick_params(which='both', length=0)\n",
    "    ax.yaxis.set_tick_params(which='both', length=0)\n",
    "\n",
    "# plot stuff here\n",
    "plot_llama_qwen(mmlu_pre, ax=axs[0])\n",
    "plot_llama_qwen(mmlu_post, ax=axs[1])\n",
    "plot_llama_qwen(gsm8k_pre, ax=axs[2])\n",
    "plot_llama_qwen(gsm8k_post, ax=axs[3])\n",
    "\n",
    "\n",
    "axs[0].set_yticks([0.3, 0.5, 0.7])\n",
    "axs[1].set_yticks([0.3, 0.5, 0.7])\n",
    "\n",
    "axs[2].set_yticks([0., 0.4, 0.8])\n",
    "axs[3].set_yticks([0., 0.4, 0.8])\n",
    "\n",
    "axs[1].set_yticklabels([])\n",
    "axs[3].set_yticklabels([])\n",
    "\n",
    "mmlu_ylim = (0.24, 0.76)\n",
    "gsm8k_ylim = (-0.02, 0.84)\n",
    "axs[0].set_ylim(*mmlu_ylim)\n",
    "axs[1].set_ylim(*mmlu_ylim)\n",
    "axs[2].set_ylim(*gsm8k_ylim)\n",
    "axs[3].set_ylim(*gsm8k_ylim)\n",
    "\n",
    "axs[0].set_ylabel('Accuracy', fontsize=12)\n",
    "axs[2].set_ylabel('Accuracy', fontsize=12)\n",
    "# for i in range(4):\n",
    "#     axs[i].set_xlabel('Pretraining compute', fontsize=12)\n",
    "\n",
    "# shift to the right\n",
    "xshift = 1.08\n",
    "axs[0].set_xlabel('Pretraining compute (FLOPs)', fontsize=12, x=xshift)\n",
    "axs[2].set_xlabel('Pretraining compute (FLOPs)', fontsize=12, x=xshift)\n",
    "\n",
    "axs[0].set_title('Unadjusted', fontsize=12)\n",
    "axs[1].set_title('Adjusted', fontsize=12)\n",
    "axs[2].set_title('Unadjusted', fontsize=12)\n",
    "axs[3].set_title('Adjusted', fontsize=12)\n",
    "\n",
    "# xlim up to 1e24\n",
    "for ax in axs:\n",
    "    ax.set_xlim(None, 1.5e24)\n",
    "\n",
    "fig.text(0.35, 1.09, 'MMLU', ha='center', va='center', fontsize=13)\n",
    "fig.text(0.88, 1.09, 'GSM8K', ha='center', va='center', fontsize=13)\n",
    "\n",
    "legend_elements = [\n",
    "    matplotlib.lines.Line2D([0], [0], color=palette[0], label='Llama 2', markersize=8, marker='o', markerfacecolor=palette[0], ls=''),\n",
    "    matplotlib.lines.Line2D([0], [0], color=palette[1], label='Qwen 1.5', markersize=8, marker='o', markerfacecolor=palette[1], ls=''),\n",
    "    matplotlib.lines.Line2D([0], [0], color=palette[2], label='Pythia', markersize=8, marker='o', markerfacecolor=palette[2], ls='')\n",
    "]\n",
    "\n",
    "legend_position = (0.57,  -0.56)\n",
    "ncol = 3\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=legend_position,\n",
    "        loc='lower center', ncol=ncol, fontsize=12, frameon=True, columnspacing=1.5, handletextpad=.9, handleheight=0.8, handlelength=1.2)\n",
    "\n",
    "plt.savefig('plots/families.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Section 5: emergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_mmlu = load_steps('evaluations/intermediate/mmluaux/evals/', 'mmlu')\n",
    "results_mmlu = {s: {m: results_mmlu[s][m]['acc'] for m in results_mmlu[s]} for s in results_mmlu}  # use acc\n",
    "results_gsm8k = load_steps('evaluations/intermediate/gsm8kaux/evals/', 'gsm8k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pwlf\n",
    "import statsmodels.api as sm\n",
    "\n",
    "fit_width = 5.5\n",
    "fit_color = palette[1]\n",
    "\n",
    "def plot_r2(ax, results, ylabel='Accuracy', title=None, ylim=None, xlim=None, msize=50, x=True, titles=True, **kwargs):\n",
    "    \"\"\"Scatter accuracy against compute and overlay an OLS fit on log-compute.\n",
    "\n",
    "    Args:\n",
    "        ax: axes to draw on.\n",
    "        results: mapping from model name to accuracy; names containing 'neo' are skipped.\n",
    "        ylabel: y-axis label.\n",
    "        title: axes title, only applied when ``titles`` is true.\n",
    "        ylim, xlim: optional (low, high) limits; the lower x limit is reset to 1e20 afterwards.\n",
    "        msize: scatter marker size.\n",
    "        x: not read by this function.\n",
    "        titles: whether ``title`` is shown.\n",
    "        **kwargs: ignored, so callers can pass the same keywords as to ``plot_emergence``.\n",
    "\n",
    "    The legend shows the R^2 of the fit with three decimals.\n",
    "    \"\"\"\n",
    "    models = [m for m in results if 'neo' not in m]\n",
    "    per1 = [compute_compute(model) for model in models]\n",
    "    per2 = [results[model] for model in models]\n",
    "\n",
    "    ax.scatter(per1, per2, color=palette[0], alpha=0.5, s=msize)\n",
    "    ax.set_ylabel(ylabel, fontsize=12)\n",
    "    ax.grid(alpha=0.5)\n",
    "\n",
    "    ax.set_xscale('log')\n",
    "\n",
    "    if title is not None and titles:\n",
    "        ax.set_title(title, fontsize=10.5)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(*ylim)\n",
    "    if xlim is not None:\n",
    "        ax.set_xlim(*xlim)\n",
    "\n",
    "    log_compute = np.log(per1)\n",
    "    fii = sm.OLS(per2, sm.add_constant(log_compute)).fit()\n",
    "    # Read the estimates from the fitted result instead of scraping summary2():\n",
    "    # the summary table is label-indexed, so [0] / [-1] on it relied on pandas'\n",
    "    # deprecated positional fallback. params is a plain array for array inputs.\n",
    "    intercept, slope = fii.params[0], fii.params[-1]\n",
    "\n",
    "    ax.legend([f'$R^2=${fii.rsquared:.3f}'], fontsize=11, handlelength=0, handletextpad=0, markerscale=0, loc='upper left', frameon=True)\n",
    "\n",
    "    ax.plot(per1, slope * log_compute + intercept, color=fit_color, alpha=0.8, linewidth=fit_width)\n",
    "    ax.set_xticks([1e21, 1e22, 1e23, 1e24])\n",
    "    ax.set_xlim(1e20, None)\n",
    "    ax.tick_params(axis='both', which='both', length=0, labelsize=10)\n",
    "    \n",
    "def plot_emergence(ax, results, ylabel='Accuracy', title=None, ylim=None, xlim=None, msize=50, yconst=0.25, x=True, titles=True):\n",
    "    \"\"\"Scatter accuracy against compute and overlay a two-segment piecewise-linear fit.\n",
    "\n",
    "    Args:\n",
    "        ax: axes to draw on.\n",
    "        results: mapping from model name to accuracy; names containing 'neo' are skipped.\n",
    "        ylabel: y-axis label.\n",
    "        title: axes title, only applied when ``titles`` is true.\n",
    "        ylim, xlim: optional (low, high) axis limits.\n",
    "        msize: scatter marker size.\n",
    "        yconst: y value of the two constraint points handed to the pwlf fit.\n",
    "        x: not read by this function.\n",
    "        titles: whether ``title`` is shown.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by ``PiecewiseLinFit.fit``; its element 1 is used\n",
    "        as the breakpoint (in log-compute) shown in the legend as $c_e$.\n",
    "    \"\"\"\n",
    "    names = [name for name in results if 'neo' not in name]\n",
    "    compute = np.array([compute_compute(name) for name in names])\n",
    "    accuracy = np.array([results[name] for name in names])\n",
    "\n",
    "    ax.scatter(compute, accuracy, color=palette[0], alpha=0.5, s=msize)\n",
    "    ax.set_ylabel(ylabel, fontsize=12)\n",
    "    ax.grid(alpha=0.5)\n",
    "    ax.set_xscale('log')\n",
    "\n",
    "    if titles and title is not None:\n",
    "        ax.set_title(title, fontsize=10.5)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(*ylim)\n",
    "    if xlim is not None:\n",
    "        ax.set_xlim(*xlim)\n",
    "\n",
    "    # only points strictly inside (1e20, 1e24) take part in the fit\n",
    "    in_range = (compute > 1e20) & (compute < 1e24)\n",
    "    log_compute = np.log(compute[in_range])\n",
    "    acc_in_range = accuracy[in_range]\n",
    "\n",
    "    # two constraint points just below 1e20, both at height yconst\n",
    "    anchor_x = np.log(np.array([0.8 * 1e20, 0.7 * 1e20]))\n",
    "    anchor_y = [yconst, yconst]\n",
    "\n",
    "    # NOTE(review): no seed is passed to pwlf here, so the fitted breakpoint may\n",
    "    # vary between runs unless randomness is fixed elsewhere -- confirm.\n",
    "    fitter = pwlf.PiecewiseLinFit(log_compute, acc_in_range)\n",
    "    res = fitter.fit(2, anchor_x, anchor_y)\n",
    "\n",
    "    grid = np.linspace(min(log_compute), max(log_compute), 100)\n",
    "    ax.plot(np.exp(grid), fitter.predict(grid), '-', c=fit_color, alpha=0.8, linewidth=fit_width)\n",
    "\n",
    "    emg = np.exp(res[1])\n",
    "    # NOTE(review): estimates above 1e23 are replaced by 2.1e22 for the legend\n",
    "    # only (the returned fit is untouched); the reason is not visible here.\n",
    "    if emg > 1e23:\n",
    "        emg = 2.1 * 1e22\n",
    "\n",
    "    ax.legend([f'$c_e$: {emg:.1e}'], fontsize=11, handlelength=0, handletextpad=0, markerscale=0, loc='upper left', frameon=True)\n",
    "    ax.set_xticks([1e21, 1e22, 1e23])\n",
    "    ax.tick_params(axis='both', which='both', length=0, labelsize=10)\n",
    "\n",
    "    return res\n",
    "\n",
    "def emergence_plots(data, base_data, steps, yconst, suptitle=None, emergence=True, ylabel='Accuracy', titles=True, x=True, axs=None, is_mmlu=False):\n",
    "    \"\"\"Draw one panel per fine-tuning step, each with a fitted curve.\n",
    "\n",
    "    Args:\n",
    "        data: mapping from step to {model: accuracy}; it is not modified.\n",
    "        base_data: {model: accuracy} used for step 0.\n",
    "        steps: steps to plot, one panel each, left to right.\n",
    "        yconst: forwarded to the panel plotting function.\n",
    "        suptitle: not read by this function.\n",
    "        emergence: use ``plot_emergence`` when true, ``plot_r2`` otherwise.\n",
    "        ylabel: y-axis label of the first panel.\n",
    "        titles: forwarded to the panel plotting function.\n",
    "        x: when true the shared x label is drawn and panels get no title;\n",
    "            when false each panel is titled with its number of task examples.\n",
    "        axs: sequence of axes, one per step; a new row of axes is created when omitted.\n",
    "        is_mmlu: selects the MMLU y ticks/limits instead of the default ones.\n",
    "\n",
    "    Only models for which ``color_rule(model) == palette[0]`` are plotted.\n",
    "    \"\"\"\n",
    "    plot_f = plot_emergence if emergence else plot_r2\n",
    "\n",
    "    # shallow copy: adding step 0 must not write into the caller's dict\n",
    "    data = {**data, 0: base_data}\n",
    "\n",
    "    if axs is None:\n",
    "        # squeeze=False keeps a 2-D array even for a single step\n",
    "        _, grid = plt.subplots(1, len(steps), figsize=(9.5,1.5), dpi=200, sharex=True, sharey=True, squeeze=False)\n",
    "        axs = grid[0]\n",
    "\n",
    "    for i, step in enumerate(steps):\n",
    "        examples = step * 64  # batch size\n",
    "        to_plot = {m: acc for m, acc in data[step].items() if color_rule(m) == palette[0]}\n",
    "\n",
    "        examples_k = '0'\n",
    "        if examples > 0:\n",
    "            if len(str(examples)) > 4:\n",
    "                examples_k = str(int(examples/1000)) + 'k'\n",
    "            else:\n",
    "                examples_k = str(examples/1000) + 'k'\n",
    "\n",
    "        title = f'Task examples: {examples_k}' if not x else None\n",
    "        plot_f(axs[i], to_plot, title=title, ylabel='', yconst=yconst, x=x, titles=titles)\n",
    "\n",
    "        if is_mmlu:\n",
    "            axs[i].set_yticks([0.25, 0.5, 0.7])\n",
    "            axs[i].set_ylim(0.2, 0.75)\n",
    "        else:\n",
    "            axs[i].set_yticks([0, 0.2, 0.4, 0.6, 0.8])\n",
    "            axs[i].set_ylim(-0.03, 0.84)\n",
    "\n",
    "        # only the left-most panel keeps its y tick labels\n",
    "        if i > 0:\n",
    "            axs[i].set_yticklabels([])\n",
    "\n",
    "    if x:\n",
    "        # the label sits on the second panel and is shifted right to look centred\n",
    "        label_ax = axs[1] if len(steps) > 1 else axs[0]\n",
    "        label_ax.set_xlabel('Pretraining compute (FLOPs)', fontsize=12, x=1.05)\n",
    "\n",
    "    axs[0].set_ylabel(ylabel, fontsize=12)\n",
    "\n",
    "    plt.subplots_adjust(wspace=0.08, hspace=0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(2, 4, figsize=(8.5,2.8), dpi=200, sharex=True, sharey=False)\n",
    "emergence_plots(results_mmlu, all_results['base']['mmlu-acc'], [0, 100, 250, 1000], 0.25, axs=axs[0], ylabel='Accuracy', emergence=True, titles=True, x=False, is_mmlu=True)\n",
    "emergence_plots(results_mmlu, all_results['base']['mmlu-acc'], [0, 100, 250, 1000], 0.25, axs=axs[1], ylabel='Accuracy', emergence=False, titles=True, x=True, is_mmlu=True)\n",
    "fig.text(0.51, 1., 'MMLU', ha='center', va='center', fontsize=12)\n",
    "plt.savefig('plots/emergence.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(2, 4, figsize=(8.5,2.8), dpi=200, sharex=True, sharey=False)\n",
    "emergence_plots(results_gsm8k, all_results['base']['gsm8k'], [0, 250, 1000, 5000], 0., axs=axs[0], ylabel='Accuracy', emergence=True, titles=True, x=False, is_mmlu=False)\n",
    "emergence_plots(results_gsm8k, all_results['base']['gsm8k'], [0, 250, 1000, 5000], 0., axs=axs[1], ylabel='Accuracy', emergence=False, titles=True, x=True, is_mmlu=False)\n",
    "fig.text(0.51, 1., 'GSM8K', ha='center', va='center', fontsize=12)\n",
    "plt.savefig('plots/emergence-gsm8k.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Appendix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Are newer models similar to older, fine-tuned models?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_old_new(results):\n",
    "    \"\"\"Split a {model: accuracy} mapping by the colour ``color_rule`` gives each model.\n",
    "\n",
    "    Returns:\n",
    "        (old, new): the entries coloured ``palette[0]`` and those coloured\n",
    "        ``palette[1]``. Entries with any other colour appear in neither.\n",
    "    \"\"\"\n",
    "    def with_colour(colour):\n",
    "        return {model: acc for model, acc in results.items() if color_rule(model) == colour}\n",
    "\n",
    "    return with_colour(palette[0]), with_colour(palette[1])\n",
    "\n",
    "mmlu_old, mmlu_new = split_old_new(all_results['base']['mmlu-acc'])\n",
    "gsm8k_old, gsm8k_new = split_old_new(all_results['base']['gsm8k'])\n",
    "mmlu_old_ft, _ = split_old_new(results_mmlu[1000])\n",
    "gsm8k_old_ft, _ = split_old_new( results_gsm8k[1000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(7.5, 1.8), dpi=150, sharex=True)\n",
    "kwargs = {'xticks': [1e20, 1e21, 1e22, 1e23, 1e24], 'title_fontsize': 12, 'msize': 50, 'color_code': 'date'}\n",
    "plot_bench(ax[0], {**mmlu_old_ft, **mmlu_new}, title='MMLU', **kwargs)\n",
    "plot_bench(ax[1], {**gsm8k_old_ft, **gsm8k_new}, title='GSM8K', **kwargs)\n",
    "\n",
    "ax[0].set_ylabel('Accuracy', fontsize=12)\n",
    "for i in range(2):\n",
    "    ax[i].set_xlabel('Pretraining compute (FLOPs)', fontsize=12)\n",
    "    ax[i].tick_params(axis='both', which='both', length=0, labelsize=12)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.2)\n",
    "\n",
    "legend_elements = [\n",
    "    # plt.Line2D([0], [0], marker='', color='w', label=legend1, markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='Fine-tuned old models', markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label='New models', markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.4),\n",
    "        loc='lower center', ncol=3, fontsize=12, frameon=True, columnspacing=0.5, handletextpad=.2,)\n",
    "\n",
    "# save as pdf on plots/\n",
    "plt.savefig('plots/ft_old.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xgboost import XGBClassifier\n",
    "from sklearn.metrics import balanced_accuracy_score\n",
    "\n",
    "def get_accuracy(set_0, set_1):\n",
    "    \"\"\"Leave-one-out balanced accuracy of an XGBoost classifier telling two sets apart.\n",
    "\n",
    "    Args:\n",
    "        set_0: feature rows labelled 0.\n",
    "        set_1: feature rows labelled 1.\n",
    "\n",
    "    Returns:\n",
    "        ``balanced_accuracy_score`` over the held-out predictions, where each\n",
    "        row is predicted by a fresh ``XGBClassifier`` fitted on all other rows.\n",
    "    \"\"\"\n",
    "    features = np.r_[set_0, set_1]\n",
    "    labels = np.r_[np.zeros(len(set_0)), np.ones(len(set_1))]\n",
    "    row_ids = np.arange(len(features))\n",
    "\n",
    "    held_out_labels, held_out_preds = [], []\n",
    "    for held_out in row_ids:  # leave-one-out\n",
    "        keep = row_ids != held_out\n",
    "        model = XGBClassifier()\n",
    "        model.fit(features[keep], labels[keep])\n",
    "        held_out_labels.append(labels[held_out])\n",
    "        held_out_preds.append(model.predict(features[held_out].reshape(1, -1)))\n",
    "\n",
    "    return balanced_accuracy_score(held_out_labels, held_out_preds)\n",
    "\n",
    "def featurize(results):\n",
    "    \"\"\"Return a two-column array: log10 of ``compute_compute(model)`` and the stored value.\"\"\"\n",
    "    log_compute = np.log10(np.array([compute_compute(model) for model in results]))\n",
    "    scores = list(results.values())\n",
    "    return np.c_[log_compute, scores]\n",
    "\n",
    "mmlu_datasets = {\n",
    "    'old': featurize(mmlu_old),\n",
    "    'new': featurize(mmlu_new),\n",
    "    'old_ft': featurize(mmlu_old_ft),\n",
    "}\n",
    "\n",
    "gsm8k_datasets = {\n",
    "    'old': featurize(gsm8k_old),\n",
    "    'new': featurize(gsm8k_new),\n",
    "    'old_ft': featurize(gsm8k_old_ft),\n",
    "}\n",
    "\n",
    "print('Discriminator test -- using both compute and accuracy')\n",
    "discriminator = {\n",
    "    'mmlu': {\n",
    "        'old': get_accuracy(mmlu_datasets['new'], mmlu_datasets['old']),\n",
    "        'old_ft': get_accuracy(mmlu_datasets['new'], mmlu_datasets['old_ft']),\n",
    "    },\n",
    "    'gsm8k': {\n",
    "        'old': get_accuracy(gsm8k_datasets['new'], gsm8k_datasets['old']),\n",
    "        'old_ft': get_accuracy(gsm8k_datasets['new'], gsm8k_datasets['old_ft']),\n",
    "    },\n",
    "}\n",
    "\n",
    "def wrap_multirrow(text, n=2):\n",
    "    \"\"\"Format a number as a LaTeX multirow cell holding a percentage.\n",
    "\n",
    "    Args:\n",
    "        text: numeric value, printed with one decimal place.\n",
    "        n: number of rows the cell spans.\n",
    "    \"\"\"\n",
    "    # the percent sign needs an escaped backslash; a lone backslash before %\n",
    "    # is an invalid escape sequence in a non-raw string\n",
    "    return f'\\\\multirow{{{n}}}{{*}}{{{text:.1f}\\\\%}}'\n",
    "\n",
    "table = f\"Older models vs & {wrap_multirrow(discriminator['mmlu']['old'] * 100)} & {wrap_multirrow(discriminator['gsm8k']['old'] * 100)} \\\\\\\\\\n\"\n",
    "table += f\"newer models & & \\\\\\\\[0.4em]\\n\"\n",
    "table += f\"Fine-tuned, older models vs & {wrap_multirrow(discriminator['mmlu']['old_ft'] * 100)} & {wrap_multirrow(discriminator['gsm8k']['old_ft'] * 100)} \\\\\\\\\\n\"\n",
    "table += f\"newer models & & \\\\\\\\\"\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Saturation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# third argument handed to plot_regressor for each benchmark\n",
    "rs = {\n",
    "    'leaderboard_mmlu_pro': 0.1,\n",
    "    'leaderboard_gpqa': 0.0,\n",
    "    'leaderboard_musr': 0.0,\n",
    "    'leaderboard_bbh': 0.3,\n",
    "    'leaderboard_math_hard': 0.0,\n",
    "}\n",
    "\n",
    "# y-axis label per benchmark\n",
    "labels = {\n",
    "    'leaderboard_mmlu_pro': 'MMLU Pro',\n",
    "    'leaderboard_gpqa': 'GPQA',\n",
    "    'leaderboard_musr': 'MuSR',\n",
    "    'leaderboard_bbh': 'BBH',\n",
    "    'leaderboard_math_hard': 'MATH Lvl 5',\n",
    "    'leaderboard_mmlu_pro_four': 'MMLU Pro\\n(4 choices)',\n",
    "}\n",
    "\n",
    "def make_plot(axs, benchmark, r, pad_bottom=False):\n",
    "    \"\"\"Draw the pre/post regressor plots of one benchmark on a pair of axes.\n",
    "\n",
    "    Args:\n",
    "        axs: two axes; axs[0] gets the 'pre' results, axs[1] the 'post' results.\n",
    "        benchmark: key into the module-level ``processed`` and ``labels`` mappings\n",
    "            (``processed`` is defined outside this cell).\n",
    "        r: forwarded to ``plot_regressor`` as its third positional argument.\n",
    "        pad_bottom: lower the shared y minimum by a further 0.02. This used to be\n",
    "            decided by reading the global loop variable ``i``, which made the\n",
    "            result depend on whatever ``i`` was left in the kernel.\n",
    "    \"\"\"\n",
    "    pre = processed[benchmark]['pre']\n",
    "    post = processed[benchmark]['post']\n",
    "\n",
    "    plot_regressor(*get_features(pre), r, ax=axs[0], main=False, linewidth=5)\n",
    "    plot_regressor(*get_features(post), r, ax=axs[1], main=False, linewidth=5)\n",
    "\n",
    "    axs[1].set_yticklabels([])\n",
    "\n",
    "    # both panels share y limits derived from the joint value range\n",
    "    min_ = min(min(pre.values()), min(post.values()))\n",
    "    max_ = max(max(pre.values()), max(post.values()))\n",
    "    min_ *= 0.95\n",
    "    if min_ < 0.1:\n",
    "        min_ -= 0.01\n",
    "    if pad_bottom:\n",
    "        min_ -= 0.02\n",
    "    for j in range(2):\n",
    "        axs[j].set_ylim(min_, max_ * 1.05)\n",
    "\n",
    "    axs[0].set_ylabel(labels[benchmark], fontsize=14)\n",
    "    for j in range(2):\n",
    "        axs[j].tick_params(axis='both', which='both', length=0, labelsize=12)\n",
    "\n",
    "fig, all_axs = plt.subplots(5, 2, figsize=(8.6, 9), dpi=150, sharex=True)\n",
    "\n",
    "for i, (benchmark, r) in enumerate(rs.items()):\n",
    "    # only the top row gets the extra room below its data\n",
    "    make_plot(all_axs[i], benchmark, r, pad_bottom=(i == 0))\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='', color='w', label=legend1, markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend2, markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend3, markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.035),\n",
    "        loc='lower center', ncol=3, fontsize=12, frameon=True, columnspacing=0.5, handletextpad=-.2,)\n",
    "\n",
    "for i in range(2):\n",
    "    all_axs[-1, i].set_xlabel('Pretraining compute (FLOPs)', fontsize=12)\n",
    "\n",
    "all_axs[0, 0].set_title('Unadjusted', fontsize=14)\n",
    "all_axs[0, 1].set_title('Adjusted', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.05, hspace=0.1)\n",
    "plt.savefig('plots/hf-v2.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, all_axs = plt.subplots(1, 2, figsize=(8.6, 2.5), dpi=150, sharex=True)\n",
    "\n",
    "make_plot(all_axs, 'leaderboard_mmlu_pro_four', 0.25)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Line2D([0], [0], marker='', color='w', label=legend1, markerfacecolor='b', markersize=0),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend2, markerfacecolor=palette[0], markersize=10),\n",
    "    plt.Line2D([0], [0], marker='o', color='w', label=legend3, markerfacecolor=palette[1], markersize=10),\n",
    "]\n",
    "\n",
    "fig.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.1),\n",
    "        loc='lower center', ncol=3, fontsize=12, frameon=True, columnspacing=0.5, handletextpad=-.2,)\n",
    "\n",
    "for i in range(2):\n",
    "    all_axs[i].set_xlabel('Pretraining compute (FLOPs)', fontsize=12)\n",
    "\n",
    "all_axs[0].set_title('Unadjusted', fontsize=14)\n",
    "all_axs[1].set_title('Adjusted', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0.2)\n",
    "plt.savefig('plots/hf-v2-mmlupro-4.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tables for the causal interpretation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_post = {\n",
    "    'MMLU': regress_seg(*get_features(mmlu_post)),\n",
    "    'GSM8K': regress_seg(*get_features(gsm8k_post)),\n",
    "}\n",
    "\n",
    "diff_mmlu = {m: mmlu_pre[m] - mmlu_post[m] for m in mmlu_pre}\n",
    "diff_gsm8k = {m: gsm8k_pre[m] - gsm8k_post[m] for m in gsm8k_pre}\n",
    "\n",
    "results_diff = {\n",
    "    'MMLU': regress_seg(*get_features(diff_mmlu)),\n",
    "    'GSM8K': regress_seg(*get_features(diff_gsm8k)),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def add_bold(str, p):\n",
    "    \"\"\"Wrap the text in a LaTeX bold command when p is below 0.05, else return it as is.\"\"\"\n",
    "    return '\\\\textbf{' + str + '}' if p < 0.05 else str\n",
    "    \n",
    "def format_number(num, n=2):\n",
    "    \"\"\"Format a number with three decimal places.\n",
    "\n",
    "    Args:\n",
    "        num: anything ``float()`` accepts.\n",
    "        n: accepted for call compatibility but not used; the output always\n",
    "            has three decimals.\n",
    "\n",
    "    Returns:\n",
    "        The formatted string; a value that rounds to zero never carries a minus sign.\n",
    "    \"\"\"\n",
    "    text = '{0:.3f}'.format(float(num))\n",
    "    # a tiny negative value would otherwise be printed as '-0.000'\n",
    "    return '0.000' if text == '-0.000' else text\n",
    "    \n",
    "# LaTeX table of the psi-hat coefficient (index 1 of results_post) for MMLU and GSM8K\n",
    "table = ''\n",
    "table += '\\\\begin{tabular}{c c c}\\n'\n",
    "table += '\\\\toprule\\n'\n",
    "table += '\\\\phantom{**} & \\\\phantom{*}MMLU\\\\phantom{*} & \\\\phantom{*}GSM8K\\\\phantom{*}' + ' \\\\\\\\\\n'\n",
    "table += '\\\\hline\\\\\\\\[-0.9em]\\n'\n",
    "\n",
    "def add_multirow(v):\n",
    "    \"\"\"Wrap v in a LaTeX multirow cell spanning two rows.\"\"\"\n",
    "    return '\\\\multirow{2}{*}{'+v+'}'\n",
    "\n",
    "# coefficient row, bold when its p-value is below 0.05\n",
    "row = [f'{format_number(results_post[k][\"coeffs\"][1], 3)}' for k in ['MMLU', 'GSM8K']]\n",
    "row = [add_bold(r, results_post[k][\"ps\"][1]) for r, k in zip(row, ['MMLU', 'GSM8K'])]\n",
    "table += add_multirow('$\\\\hat{\\\\psi}$') + ' & ' + ' & '.join(row) + ' \\\\\\\\\\n'\n",
    "# standard errors, in parentheses\n",
    "row = [f'({format_number(results_post[k][\"stderrs\"][1], 2)})' for k in ['MMLU', 'GSM8K']]\n",
    "table += '& ' + ' & '.join(row) + ' \\\\\\\\[0.3em]\\n'\n",
    "\n",
    "# R^2 row; it must end with a LaTeX row break followed by a real newline\n",
    "row = [format_number(results_post[k][\"r2\"]) for k in ['MMLU', 'GSM8K']]\n",
    "table += 'R$^2$ & ' + ' & '.join(row) + ' \\\\\\\\\\n'\n",
    "table += '\\\\bottomrule\\\\\\\\[-0.8em]\\n'\n",
    "table += '\\\\end{tabular}'\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX table of the phi-hat coefficient (index 1 of results_diff) for MMLU and GSM8K.\n",
    "# add_multirow, add_bold and format_number come from the previous cell.\n",
    "table = ''\n",
    "table += '\\\\begin{tabular}{c c c}\\n'\n",
    "table += '\\\\toprule\\n'\n",
    "# first is dataset names\n",
    "\n",
    "table += '\\\\phantom{**} & \\\\phantom{*}MMLU\\\\phantom{*} & \\\\phantom{*}GSM8K\\\\phantom{*}' + ' \\\\\\\\\\n'\n",
    "table += '\\\\hline\\\\\\\\[-0.9em]\\n'\n",
    "\n",
    "# coefficient row, bold when its p-value is below 0.05\n",
    "row = [f'{format_number(results_diff[k][\"coeffs\"][1], 3)}' for k in ['MMLU', 'GSM8K']]\n",
    "row = [add_bold(r, results_diff[k][\"ps\"][1]) for r, k in zip(row, ['MMLU', 'GSM8K'])]\n",
    "table += add_multirow('$\\\\hat{\\\\phi}$') + ' & ' + ' & '.join(row) + ' \\\\\\\\\\n'\n",
    "# standard errors, in parentheses\n",
    "row = [f'({format_number(results_diff[k][\"stderrs\"][1], 2)})' for k in ['MMLU', 'GSM8K']]\n",
    "table += '&' + ' & '.join(row) + ' \\\\\\\\[0.3em]\\n'\n",
    "# R^2 row\n",
    "row = [format_number(results_diff[k][\"r2\"]) for k in ['MMLU', 'GSM8K']]\n",
    "table += 'R$^2$ & ' + ' & '.join(row) + ' \\\\\\\\\\n'\n",
    "\n",
    "table += '\\\\bottomrule\\\\\\\\[-0.8em]\\n'\n",
    "table += '\\\\end{tabular}'\n",
    "print(table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hug",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
