{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "\n",
    "# Project-specific setup; the argument is presumably the repository root relative to this notebook (TODO confirm).\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "# Reload edited modules (e.g. res_util) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from experiments.memorization_dynamics import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Memorization Dynamics: Different Alphabet Sizes\n",
    "\n",
    "We train models for 100 epochs on random strings of 1024 tokens, varying the alphabet size ℓ ∈ {2, 4, 7, 13, 26}."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35828d-0e51-4f7c-9f96-e2e20545394e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by every model section below.\n",
    "ALPHABET_SIZES = [2, 4, 7, 13, 26]\n",
    "LEGEND_TITLE = \"Alphabet Size\"\n",
    "VARIATION_DIMENSION = \"alphabet-size\"\n",
    "FIGURE_FOLDER = f\"{res_util.FIGURE_FOLDER}/alphabet_size\"\n",
    "\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    seed_ids: list[int],\n",
    "    alphabet_sizes: list[int] = ALPHABET_SIZES,\n",
    ") -> dict:\n",
    "    \"\"\"Load the memorization-dynamics results of `model` for every alphabet size.\n",
    "\n",
    "    Args:\n",
    "        model: Model identifier, passed to `res_util.load` together with the dataset name.\n",
    "        seed_ids: Seeds to load for each alphabet size.\n",
    "        alphabet_sizes: Alphabet sizes to load (defaults to ALPHABET_SIZES; not mutated).\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping the legend label \"ℓ = <size>\" to the result of\n",
    "        `res_util.load` for the dataset \"rand_a-lat-<size>_t-1024\".\n",
    "    \"\"\"\n",
    "    return {\n",
    "        f\"ℓ = {alphabet_size}\": res_util.load(\n",
    "            [f\"rand_a-lat-{alphabet_size}_t-1024\", model],\n",
    "            seed_ids,\n",
    "        )\n",
    "        for alphabet_size in alphabet_sizes\n",
    "    }\n",
    "\n",
    "\n",
    "def show(results: dict):\n",
    "    \"\"\"Plot `results` with `res_util.show_dynamics` and return its output.\n",
    "\n",
    "    The returned value is what `produce_paper_plots` expects as `figures`.\n",
    "    \"\"\"\n",
    "    return res_util.show_dynamics(\n",
    "        results,\n",
    "        LEGEND_TITLE,\n",
    "    )\n",
    "\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figures: dict,\n",
    "    model: str,\n",
    "    show_legend: bool = False,\n",
    "    show_discrepancy: bool = False,\n",
    "):\n",
    "    \"\"\"Produce the accuracy, loss, cumulative-probability, entropy and KLD paper plots.\n",
    "\n",
    "    Args:\n",
    "        figures: Figures returned by `show`.\n",
    "        model: Model label forwarded to every producer (e.g. \"pythia-70m\").\n",
    "        show_legend: Forwarded to every producer as `show_legend`.\n",
    "        show_discrepancy: Currently unused; kept so existing calls keep working.\n",
    "    \"\"\"\n",
    "    # Look the producers up at call time so `%autoreload 2` changes to res_util are used.\n",
    "    producers = (\n",
    "        res_util.produce_accuracy_paper_plot,\n",
    "        res_util.produce_loss_paper_plot,\n",
    "        res_util.produce_cum_prob_paper_plot,\n",
    "        res_util.produce_entropy_paper_plot,\n",
    "        res_util.produce_kld_paper_plot,\n",
    "    )\n",
    "    for produce in producers:\n",
    "        produce(\n",
    "            figures,\n",
    "            model,\n",
    "            variation_dimension=VARIATION_DIMENSION,\n",
    "            figure_folder=FIGURE_FOLDER,\n",
    "            show_legend=show_legend,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "864d6fb0-95a8-45b3-8663-834ccd6951ad",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-70M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b45623e-49d9-4bc0-896b-1583ebfe200d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pythia-70M: 10 seeds per alphabet size.\n",
    "seed_ids = list(range(10))\n",
    "results = load_results(\"pythia-70m\", seed_ids)\n",
    "pythia_70m_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d72a40e4-d0f8-4e12-8540-37273ae31df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(pythia_70m_figures, \"pythia-70m\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3942594f-5caf-4e14-92ac-9c3be7cd91f9",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c24a7c11-b351-4071-9e63-c3911bfe42e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pythia-1B: 10 seeds per alphabet size.\n",
    "seed_ids = list(range(10))\n",
    "results = load_results(\"pythia-1b\", seed_ids)\n",
    "pyt_1b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "025d7054-25fd-4db7-816d-5df303c15532",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(pyt_1b_figures, \"pythia-1b\", show_legend=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db3c36b1-7716-4fab-abc4-5d9e4499e0a4",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-12B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75dcd809-af08-4a6c-ab28-d7a974981c17",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pythia-12B: only 5 seeds per alphabet size.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(\"pythia-12b\", seed_ids)\n",
    "pythia_12b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bdee64c-9d30-4a57-991e-9faa7bb1a6ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(pythia_12b_figures, \"pythia-12b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9bd2236-1095-4105-82a8-2857907bd657",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87884481-7633-46d7-a5e4-57cfb7a4a5fd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Llama2-7B: only seeds 0-4 are loaded.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(\"Llama-2-7b-hf\", seed_ids)\n",
    "llama2_7b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37ee0fd8-7619-452e-8717-137343026c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(llama2_7b_figures, \"llama2-7b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5789a54f-fe86-488e-9f3c-d4932ed80c47",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8673f9-8c18-473d-80dd-d92398790c79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Llama2-13B: 10 seeds per alphabet size.\n",
    "seed_ids = list(range(10))\n",
    "results = load_results(\"Llama-2-13b-hf\", seed_ids)\n",
    "llama2_13b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b53ef1dd-e2b0-4946-ba2b-14fe1af700c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(llama2_13b_figures, \"llama2-13b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63634362-9aab-4111-92a2-62b3bdb6b2c6",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-1.5 (1.3B parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75932d34-066f-458c-9cdb-9f3395a60d9c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Phi-1.5: 10 seeds per alphabet size.\n",
    "seed_ids = list(range(10))\n",
    "results = load_results(\"phi-1_5\", seed_ids)\n",
    "phi_1_5_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0af8b6a2-09a6-4973-b674-3da60aca1f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(phi_1_5_figures, \"phi-1.3b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c61392a4-cad8-412d-bd03-ed525213a7b7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05593709-7be2-4ac2-a64a-d56fa0406529",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Phi-2: 10 seeds per alphabet size.\n",
    "seed_ids = list(range(10))\n",
    "results = load_results(\"phi-2\", seed_ids)\n",
    "phi_2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9f8bc87-d328-41d6-869e-629df44bc3c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(phi_2_figures, \"phi-2.7B\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd420316-9a4e-4687-9ab3-e698ecf4f887",
   "metadata": {
    "tags": []
   },
   "source": [
    "## GPT2-124M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30078bfc-4628-40ae-a030-dcb237413b0c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# GPT2-124M: only 5 seeds per alphabet size.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(\"gpt2\", seed_ids)\n",
    "gpt2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad90b9f3-8e99-4388-9597-fddf3b5c03c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(gpt2_figures, \"gpt2-124m\", show_legend=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e48e235-52c1-4026-8f7e-204711d49c4e",
   "metadata": {
    "tags": []
   },
   "source": [
    "## GPT2-1.5B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ebf4218-b90e-4312-8fa1-225e9fbea7af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# GPT2-1.5B (gpt2-xl): only 5 seeds per alphabet size.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(\"gpt2-xl\", seed_ids)\n",
    "gpt2_1_5b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51729430-b347-44c4-aca3-df27ebb9e0ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(gpt2_1_5b_figures, \"gpt2-1.5b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef4f29b2-efc3-4ee1-84bf-962a81d0c8b6",
   "metadata": {
    "tags": []
   },
   "source": [
    "## OPT-350M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40a3609f-3965-44b0-a891-1ea425a85530",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# OPT-350M: only 5 seeds per alphabet size.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(\"opt-350m\", seed_ids)\n",
    "opt_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab40316-bb29-4f77-9721-408f7b6382fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(opt_figures, \"opt-350m\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f0bc98-5ace-4c16-b155-19b34e145299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload via res_util.publish (TODO confirm exactly which artifacts it copies).\n",
    "publish_target = \"dynamics_analysis/alphabet_size\"\n",
    "res_util.publish(publish_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de626a7a-9492-4fe0-ad94-b525e91d1c84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
