{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from experiments.memorization_dynamics import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Memorization Dynamics: String Partitions\n",
    "\n",
    "We train models for 100 epochs on random strings with different alphabet sizes and 1024 tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35828d-0e51-4f7c-9f96-e2e20545394e",
   "metadata": {},
   "outputs": [],
   "source": [
    "PARTITIOANONYMOUS = [2, 4, 8, 16, 32, 64]\n",
    "LEGEND_TITLE = \"# Partitions\"\n",
    "VARIATION_DIMEANONYMOUSION = \"partitions\"\n",
    "FIGURE_FOLDER = f\"{res_util.FIGURE_FOLDER}/partitions\"\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    seed_ids: list[int],\n",
    "    alphabet_size: int,\n",
    "    partitions: list[int] = PARTITIOANONYMOUS,\n",
    ") -> dict:\n",
    "    return {\n",
    "        f\"n = 1\": res_util.load(\n",
    "            [f\"rand_a-lat-{alphabet_size}_t-1024\", model],\n",
    "            seed_ids,\n",
    "        )\n",
    "    } | {\n",
    "        f\"n = {num_partitions}\": res_util.load(\n",
    "            [f\"rand_a-lat-{alphabet_size}_t-1024_p-{num_partitions}\", model],\n",
    "            seed_ids,\n",
    "        )\n",
    "        for num_partitions in partitions\n",
    "    }\n",
    "\n",
    "def show(results: dict):\n",
    "    return res_util.show_dynamics(\n",
    "         results,\n",
    "         LEGEND_TITLE,\n",
    "    )\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figures: dict,\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    show_legend: bool = False,\n",
    "    show_discrepancy: bool = False,\n",
    "):\n",
    "    res_util.produce_accuracy_paper_plot(\n",
    "        figures,\n",
    "        model,\n",
    "        variation_dimension=f\"{VARIATION_DIMEANONYMOUSION}_a-{alphabet_size}\",\n",
    "        show_legend=show_legend,\n",
    "        figure_folder=FIGURE_FOLDER,\n",
    "    )\n",
    "    res_util.produce_loss_paper_plot(\n",
    "        figures,\n",
    "        model,\n",
    "        variation_dimension=f\"{VARIATION_DIMEANONYMOUSION}_a-{alphabet_size}\",\n",
    "        show_legend=show_legend,\n",
    "        figure_folder=FIGURE_FOLDER,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c422e774-0519-40c5-98eb-c5c060ad253a",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B, alphabet size 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a471c41-6f74-4d00-85b6-a180fe0128bc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = results = load_results(\n",
    "    \"pythia-1b\",\n",
    "    seed_ids,\n",
    "    alphabet_size=26,\n",
    ")   \n",
    "pyt_1b_a26_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebf986ab-a92b-461b-9dda-6dfed8b76685",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pyt_1b_a26_figures,\n",
    "    \"pythia-1b\",\n",
    "    alphabet_size=26,\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4a013b8-35b1-4d0d-a6ed-9b4e543408b0",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B, alphabet size 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "528e68a6-3570-4717-b46a-665d4e5308e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = results = load_results(\n",
    "    \"pythia-1b\",\n",
    "    seed_ids,\n",
    "    alphabet_size=2,\n",
    ")   \n",
    "pyt_1b_a2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c61efdc6-e54f-4c60-8129-fec09ce4fed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pyt_1b_a2_figures,\n",
    "    \"pythia-1b\",\n",
    "    alphabet_size=2,\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6f477dc-f85b-4eae-bd45-daba680d7541",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B, alphabet size 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3d9459-84bf-437f-8908-ebee9bb60c92",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"Llama-2-13b-hf\",\n",
    "    seed_ids,\n",
    "    alphabet_size=26,\n",
    ")   \n",
    "llama2_13b_a26_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e735dcb0-11d4-4a75-8e5a-14f16f3255c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    llama2_13b_a26_figures,\n",
    "    \"llama2-13b\",\n",
    "    alphabet_size=26,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d0f552e-69b4-4790-a0c4-b8c74c171db3",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B, alphabet size 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b582d3b-841a-4ccb-9edb-71fda3f57e1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"Llama-2-13b-hf\",\n",
    "    seed_ids,\n",
    "    alphabet_size=2,\n",
    ")   \n",
    "llama2_13b_a2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c37f19a8-d2ca-4f25-8b2f-c480fc1170de",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    llama2_13b_a2_figures,\n",
    "    \"llama2-13b\",\n",
    "    alphabet_size=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1b39d60-244f-4b9d-b9d0-138715eceab4",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters), alphabet size 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5740bc6f-34b0-402c-8da3-62e639923cce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"phi-2\",\n",
    "    seed_ids,\n",
    "    alphabet_size=26,\n",
    ")\n",
    "phi_2_a26_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "996b1205-e8cd-40db-9ba3-74f4b1361f24",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    phi_2_a26_figures,\n",
    "    \"phi-2.7b\",\n",
    "    alphabet_size=26,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8887602a-09c6-45f5-a63b-0145c3f347b1",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters), alphabet size 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f42700da-1b8e-4bf4-a77f-153664605b1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"phi-2\",\n",
    "    seed_ids,\n",
    "    alphabet_size=2,\n",
    ")\n",
    "phi_2_a2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0260861-d17f-45a1-b60d-62861473e25d",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    phi_2_a2_figures,\n",
    "    \"phi-2.7b\",\n",
    "    alphabet_size=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab40316-bb29-4f77-9721-408f7b6382fd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"dynamics_analysis/partitions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de626a7a-9492-4fe0-ad94-b525e91d1c84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
