{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from experiments.practical_memorization_dynamics import results as res_util\n",
    "from experiments.memorization_dynamics import results as md_res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Practical Memorization Dynamics: Different Relative Context Sizes\n",
    "\n",
    "We train pretrained and untrained models to memorize random strings of length 256 with different alphabet sizes (2, 7, 26).\n",
    "We present the strings to the model as substrings of larger sequences of length 256, 512, 1024 and 2048.\n",
    "Each sequence is sampled from a natural language dataset ([wikitext-103-raw-v1](https://huggingface.co/datasets/wikitext)), and we replace 256 of the tokens with those of the random string, at a random position inside the sequence.\n",
    "Only the random string is repeated, the natural data sequences change in each iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35828d-0e51-4f7c-9f96-e2e20545394e",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONTEXT_SIZES = [1, 2, 4, 8]\n",
    "LEGEND_TITLE = \"Context Size\"\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    seed_ids: list[int],\n",
    "    context_sizes: list[int] = CONTEXT_SIZES,\n",
    ") -> dict:\n",
    "    results = {\n",
    "        f\"CS = {context_size * 256}\": res_util.load(\n",
    "            f\"{model}_a-{alphabet_size}_t-256_c-wiki_x-{context_size}\",\n",
    "            seed_ids,\n",
    "        )\n",
    "        for context_size in context_sizes\n",
    "    }\n",
    "    results = res_util.convert_to_discrete_training_steps(results)\n",
    "    return results\n",
    "\n",
    "def show_results(results: dict):\n",
    "    return res_util.show_dynamics(\n",
    "        results,\n",
    "        LEGEND_TITLE,\n",
    "        show_kld=False,\n",
    "        show_in_context_learning=False,\n",
    "    )\n",
    "show_multi_results = partial(\n",
    "    res_util.show_alphabet_pretraining_results,\n",
    "    load_results,\n",
    "    show_results,\n",
    ")\n",
    "\n",
    "\n",
    "produce_paper_plots = partial(\n",
    "    res_util.produce_paper_plots,\n",
    "    figure_folder=\"mem_dynamics_rw_validation/context_size\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3942594f-5caf-4e14-92ac-9c3be7cd91f9",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf292cce-b3bd-4df1-96d8-e2d07bbba296",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(6))\n",
    "show_multi_results(\n",
    "    produce_paper_plots,\n",
    "    # None,\n",
    "    \"pyt-1b\",\n",
    "    seed_ids,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ef9ea0-51ca-430c-b9c9-871d9439674d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results[\"c = 4\"][0].value.training_log"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5789a54f-fe86-488e-9f3c-d4932ed80c47",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8673f9-8c18-473d-80dd-d92398790c79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "show_multi_results(\n",
    "    produce_paper_plots,\n",
    "    # None,\n",
    "    \"llama2-13b\",\n",
    "    seed_ids,\n",
    "    pretrained=True,\n",
    "    untrained=True,\n",
    "    alphabet_sizes=[2, 7, 26],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c61392a4-cad8-412d-bd03-ed525213a7b7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae935cc-81f5-4cb5-8ce6-1c072d31a2fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "show_multi_results(\n",
    "    produce_paper_plots,\n",
    "    # None,\n",
    "    \"phi-2.7b\",\n",
    "    seed_ids,\n",
    "    untrained=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f0bc98-5ace-4c16-b155-19b34e145299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"dynamics_analysis/context_size\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de626a7a-9492-4fe0-ad94-b525e91d1c84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
