{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from IPython.display import Markdown as md\n",
    "\n",
    "from lib_project.visualization import with_paper_style\n",
    "from experiments.repeated_training import results as res_util\n",
    "from experiments.memorization_dynamics import results as mem_res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Memorizing strings of different lengths in sequence\n",
    "\n",
    "We train a Pythia 1B model to memorize different random strings for 50 epochs, one after the other.\n",
    "That is, given strings $s_1, \\dots, s_n$, we first train model $m$ on $s_1$, then train the resulting model on $s_2$, and so on.\n",
    "We evaluate the memorization dynamics (loss, accuracy, aggregate probability over the alphabet, and entropy) for each string as training proceeds, including how training on subsequent strings affects previously memorized strings.\n",
    "\n",
    "Here, we train models on strings that differ in their length (number of tokens).\n",
    "We show results for training a model on 8 strings of 1024 tokens each with a 26-token alphabet, as well as on 32 strings of 1024 tokens each with 2-token and 26-token alphabets.\n",
    "\n",
    "### Takeaways\n",
    "\n",
    "- Strings memorized later overwrite information about previously memorized strings. The loss on an earlier string increases drastically and its accuracy decreases drastically as soon as another string is memorized.\n",
    "- Having already trained on other strings makes a new string easier to memorize. When memorizing 8 strings of 1024 tokens each, each subsequent string is memorized faster (see the initial loss and accuracy comparison plots), except string 3, whose curves are close to those of the first string.\n",
    "- Strings memorized later in training seem to be forgotten less during training on subsequent strings than strings memorized earlier."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fe49481-2b2f-45ba-b3a3-dd6587f3ffa2",
   "metadata": {},
   "source": [
    "## Training on one 1024-token string with a 2-character alphabet for 500 epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2796ea4f-848f-4135-834b-30c67c42435a",
   "metadata": {},
   "outputs": [],
   "source": [
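    "# 0-based IDs of the strings whose training runs are inspected below.\n",
    "STRING_IDS = [0, 26, 31]\n",
    "\n",
    "\n",
    "def load(config_name: str):\n",
    "    \"\"\"Return the memorization log of the first iteration of the first\n",
    "    loaded result for `config_name`, wrapped in a single-element list.\"\"\"\n",
    "    return [\n",
    "        res_util.load(config_name, [0])[0]\n",
    "            .value\n",
    "            .iteration_results[0]\n",
    "            .memorization_log\n",
    "    ]\n",
    "\n",
    "\n",
    "# Memorization dynamics per string, keyed by a 1-based display name.\n",
    "mem_results = {\n",
    "    f\"String {string_id + 1}\": load(f\"pyt-1b_a-2_t-1024_s-{string_id}\")\n",
    "    for string_id in STRING_IDS\n",
    "}\n",
    "figs = mem_res_util.show_dynamics_overview(\n",
    "    mem_results,\n",
    "    title=\"String ID\",\n",
    "    show_cum_prob=False,\n",
    "    show_discrepancy=False,\n",
    "    show_in_context_learning=False,\n",
    ")\n",
    "\n",
    "# Full results per string, used for the prefix performance plots.\n",
    "results = {\n",
    "    f\"String {string_id + 1}\": res_util.load(f\"pyt-1b_a-2_t-1024_s-{string_id}\", [0])\n",
    "    for string_id in STRING_IDS\n",
    "}\n",
    "for string_name, string_res in results.items():\n",
    "    display(md(f\"#### {string_name}\"))\n",
    "    res_util.show_prefix_performance(string_res, title=f\"{string_name} Prefix Performance\")"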
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"long_training\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
