{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from IPython.display import Markdown as md\n",
    "\n",
    "from lib_project.visualization import with_paper_style\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.repeated_training import results as res_util\n",
    "from experiments.memorization_dynamics import results as mem_res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Memorizing strings of different lengths in sequence\n",
    "\n",
    "We train a Pythia 1B model to memorize different random strings for 50 epochs, one after the other.\n",
    "I.e. given strings $s_1, \\dots, s_n$, we first train model $m$ on $s_1$, then the resulting model on $s_2$, etc.\n",
    "We evaluate the memorization dynamics (loss, accuracy, aggregate probability over the alphabet, and entropy) for each string as training proceeds, i.e. also how training on subsequent strings affects previously memorized strings.\n",
    "\n",
    "Here, we train models of strings that differ in their length/number of tokens.\n",
    "We show results for training a model on 8 x 1024 token strings with a 26 token alphabet, as well as on 32 x 1024 token strings with 2 and 26 token alphabets.\n",
    "\n",
    "### Takeaways\n",
    "\n",
    "- Strings memorized later override information about previously memorized strings. The loss drastically increases and the accuracy drastically decreases as soon as another string is memorized.\n",
    "- Training on a string after having trained on other strings before makes it easier to memorize it. When memorizing 8 x 1024 token strings, each subsequent string is memorized faster (see initial loss and accuracy comparison plots), except string 3, which is close to the first string.\n",
    "- Strings later in training seem to be forgotten less when training on subsequent strings than those earlier during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990b91c9-a5ce-4a54-8858-515d2fb67735",
   "metadata": {},
   "outputs": [],
   "source": [
    "def produce_paper_plots(\n",
    "    sequential_figures: dict,\n",
    "    parallel_figures: dict,\n",
    "    alphabet_size: int,\n",
    "):\n",
    "    res_util.produce_accuracy_paper_plot(\n",
    "        sequential_figures,\n",
    "        \"pythia-1b\",\n",
    "        f\"a-{alphabet_size}_x32_sequential\",\n",
    "    )\n",
    "    res_util.produce_accuracy_paper_plot(\n",
    "        parallel_figures,\n",
    "        \"pythia-1b\",\n",
    "        f\"a-{alphabet_size}_x32_parallel\",\n",
    "        filter_runs = [0, 7, 15, 23, 31],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a192f40d-2952-4596-8586-32877ce0d252",
   "metadata": {},
   "source": [
    "## Training on 8 different 1024 token strings\n",
    "\n",
    "Strings have a 26 character alphabet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0c12d2f-fe38-4e76-8c6d-735b0fd9ab1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = [0]\n",
    "results = res_util.load(\"pyt-1b_a-26_t-1024_x8\", seed_ids)\n",
    "res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 8,\n",
    ")\n",
    "res_util.show_prefix_performance(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dc285a7-28de-48f1-abb8-6eabacf73107",
   "metadata": {},
   "outputs": [],
   "source": [
    "results[0].value.iteration_results[0].prefix_mappings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a48a6758-7143-4160-8cdd-8b60939e44fb",
   "metadata": {},
   "source": [
    "## Training on 32 x 1024 token strings, 26 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6441815b-c0f5-422f-bb82-0384d6acd439",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = [0]\n",
    "results = res_util.load(\"pyt-1b_a-26_t-1024_x32\", seed_ids)\n",
    "a26_seq_figures, a26_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 32,\n",
    ")\n",
    "res_util.show_prefix_performance(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02f5857c-a9d0-4f40-b3d3-40ca76239d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    a26_seq_figures,\n",
    "    a26_para_figures,\n",
    "    26,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9296eb-a2fc-4117-abab-9947ca532ea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# paper_fig = with_paper_style(\n",
    "#     a26_seq_figures[\"training_loss\"],\n",
    "#     # legend_pos=(1, 0),\n",
    "#     # legend_yanchor=\"bottom\",\n",
    "#     # legend_orientation=\"h\",\n",
    "# )\n",
    "# # paper_fig.update_yaxes(range=[-0.05, 10.05])\n",
    "# paper_fig.show()\n",
    "# paper_fig.write_image(BASE_FIGURE_DIR / \"repeated_memorization/a-26_x32_loss.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514cc8a7-b202-4a32-8465-b07fbdadb1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "paper_fig = with_paper_style(\n",
    "    a26_seq_igures[\"training_accuracy\"],\n",
    ")\n",
    "paper_fig.update_yaxes(range=[-0.05, 1.05])\n",
    "paper_fig.show()\n",
    "paper_fig.write_image(BASE_FIGURE_DIR / \"repeated_memorization/a-26_x32_accuracy.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eec30443-14d7-454f-bbbc-f2baf4ee9abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "paper_fig = with_paper_style(\n",
    "    a26_figures[\"training_accuracy\"],\n",
    ")\n",
    "# paper_fig.update_yaxes(range=[-0.05, 10.05])\n",
    "paper_fig.show()\n",
    "paper_fig.write_image(BASE_FIGURE_DIR / \"repeated_memorization/a-26_x32_accuracy.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97cd82f3-163c-4a5a-a754-ef1320c92768",
   "metadata": {},
   "source": [
    "## Training on 32 x 1024 token strings, 2 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dc2bbc4-4815-448a-b663-d911a4b3b617",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = [0]\n",
    "results = res_util.load(\"pyt-1b_a-2_t-1024_x32\", seed_ids)\n",
    "a2_seq_figures, a2_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 32,\n",
    ")\n",
    "res_util.show_prefix_performance(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80dbc9ba-5290-4cb6-8ed2-7e1c1a9d53be",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    a2_seq_figures,\n",
    "    a2_para_figures,\n",
    "    2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62207f0-bb92-4d53-8bcc-e53b1f398128",
   "metadata": {},
   "outputs": [],
   "source": [
    "paper_fig = with_paper_style(\n",
    "    a2_figures[\"training_loss\"],\n",
    ")\n",
    "paper_fig.show()\n",
    "paper_fig.write_image(BASE_FIGURE_DIR / \"repeated_memorization/a-2_x32_loss.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e18990a1-28a1-4050-be60-b1f4822fb9b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "paper_fig = with_paper_style(\n",
    "    a2_figures[\"training_accuracy\"],\n",
    ")\n",
    "paper_fig.show()\n",
    "paper_fig.write_image(BASE_FIGURE_DIR / \"repeated_memorization/a-2_x32_accuracy.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4353577a-4ef8-454b-b126-dd1a8ba75052",
   "metadata": {},
   "source": [
    "## Training on 32 x 1024 token strings, 26 character alphabet, different seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ea0d2dc-2790-47fa-b9be-d6ba30d94f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = [1]\n",
    "results = res_util.load(\"pyt-1b_a-26_t-1024_x32\", seed_ids)\n",
    "res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 32,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa20ad7d-6667-41ac-ba4c-8489aedee4de",
   "metadata": {},
   "source": [
    "## Training on 32 x 1024 token strings, 2 character alphabet, different seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d06be66-4a79-406c-9a00-4fb9f86d19de",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = [1]\n",
    "results = res_util.load(\"pyt-1b_a-2_t-1024_x32\", seed_ids)\n",
    "res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 32,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87856539-29dd-40bf-a392-9e2499cb5188",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"same_strings\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cf2c77f-d082-4e17-9e0f-7ace2f195854",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
