{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from experiments.practical_memorization_dynamics import results as res_util\n",
    "from experiments.memorization_dynamics import results as md_res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Practical Memorization Dynamics: Different Batch Sizes\n",
    "\n",
    "We train pretrained and untrained models to memorize random strings of length 1024 with different alphabet sizes (2, 7, 26).\n",
    "We present the strings to the model inside batches of different sizes, where one sequence in the batch is the random string, and the others are taken from a natural language dataset ([wikitext-103-raw-v1](https://huggingface.co/datasets/wikitext)).\n",
    "We use batch sizes of 1, 4, 16 and 64.\n",
    "Only the random string is repeated, the natural data sequences change in each iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35828d-0e51-4f7c-9f96-e2e20545394e",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZES = [1, 4, 16, 64]\n",
    "ALPHABET_SIZES = [2, 7, 26]\n",
    "LEGEND_TITLE = \"Batch Size\"\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    seed_ids: list[int],\n",
    "    batch_sizes: list[int] = BATCH_SIZES,\n",
    ") -> dict:\n",
    "    results = {\n",
    "        f\"BS = {batch_size}\": res_util.load(\n",
    "            f\"{model}_a-{alphabet_size}_t-1024_c-wiki_b-{batch_size}\",\n",
    "            seed_ids,\n",
    "        )\n",
    "        for batch_size in batch_sizes\n",
    "    }\n",
    "    results = res_util.convert_to_discrete_training_steps(results)\n",
    "    return results\n",
    "\n",
    "def show_results(results: dict):\n",
    "    return res_util.show_dynamics(\n",
    "        results,\n",
    "        LEGEND_TITLE,\n",
    "        show_kld=False,\n",
    "        show_in_context_learning=False,\n",
    "    )\n",
    "\n",
    "show_multi_results = partial(\n",
    "    res_util.show_alphabet_pretraining_results,\n",
    "    load_results,\n",
    "    show_results,\n",
    ")\n",
    "\n",
    "def show_multi_results(\n",
    "    produce_paper_plots,\n",
    "    model: str,\n",
    "    seed_ids: list[int],\n",
    "    batch_sizes: list[int] = BATCH_SIZES,\n",
    "    alphabet_sizes: list[int] = ALPHABET_SIZES,\n",
    "    pretrained: bool = True,\n",
    "    untrained: bool = True,\n",
    "):\n",
    "    res_util.show_alphabet_pretraining_results(\n",
    "        partial(load_results, batch_sizes=batch_sizes),\n",
    "        show_results,\n",
    "        produce_paper_plots,\n",
    "        model,\n",
    "        seed_ids,\n",
    "        alphabet_sizes,\n",
    "        pretrained,\n",
    "        untrained,\n",
    "    )\n",
    "\n",
    "produce_paper_plots = partial(\n",
    "    res_util.produce_paper_plots,\n",
    "    figure_folder=\"mem_dynamics_rw_validation/batch_size\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3942594f-5caf-4e14-92ac-9c3be7cd91f9",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4abbf80-8306-4b83-a485-39dc20c8fc97",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(6))\n",
    "show_multi_results(\n",
    "    produce_paper_plots,\n",
    "    # None,\n",
    "    \"pyt-1b\",\n",
    "    seed_ids,\n",
    "    batch_sizes=[1, 4, 16, 64],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2e5f584-6562-4318-8a64-cd8f2bc90117",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# seed_ids = list(range(3))\n",
    "# results = results = load_results(\n",
    "#     \"pyt-1b\",\n",
    "#     2,\n",
    "#     seed_ids,\n",
    "#     # batch_sizes=[1, 4],\n",
    "# )   \n",
    "# pyt_1b_figures = show_results(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc48c319-2ebb-440c-9ab9-aab0d027757a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.exp(results[\"b = 16\"][0].value.training_log[\"eval_loss\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5789a54f-fe86-488e-9f3c-d4932ed80c47",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8673f9-8c18-473d-80dd-d92398790c79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "show_multi_results(\n",
    "    produce_paper_plots,\n",
    "    # None,\n",
    "    \"llama2-13b\",\n",
    "    seed_ids,\n",
    "    batch_sizes=[1, 4, 16],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c61392a4-cad8-412d-bd03-ed525213a7b7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462bbf7d-d6e4-4bdc-8c11-b692de1df638",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "show_multi_results(\n",
    "    produce_paper_plots,\n",
    "    # None,\n",
    "    \"phi-2.7b\",\n",
    "    seed_ids,\n",
    "    batch_sizes=[1, 4, 16],\n",
    "    alphabet_sizes=[2, 7, 26],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f0bc98-5ace-4c16-b155-19b34e145299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"dynamics_analysis/batch_size\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de626a7a-9492-4fe0-ad94-b525e91d1c84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
