{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from IPython.display import Markdown as md\n",
    "\n",
    "from lib_project.visualization import with_paper_style\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.repeated_training import results as res_util\n",
    "from experiments.memorization_dynamics import results as mem_res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Memorizing strings in sequence, with untrained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53b69518-09ce-4612-afcf-083de820bf81",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALPHABET_SIZES = [2, 26]\n",
    "LEGEND_TITLE = \"Iterations\"\n",
    "VARIATION_DIMEANONYMOUSION = \"alphabet-size\"\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    seed_ids: list[int],\n",
    "    repetitions: int,\n",
    ") -> list:\n",
    "    return res_util.load(f\"{model}-ut_a-{alphabet_size}_x{repetitions}\", seed_ids)\n",
    "\n",
    "def produce_paper_plots(\n",
    "    sequential_figures: dict,\n",
    "    parallel_figures: dict,\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    repetitions: int,\n",
    "):\n",
    "    res_util.produce_accuracy_paper_plot(\n",
    "        sequential_figures,\n",
    "        model,\n",
    "        f\"a-{alphabet_size}_x{repetitions}_sequential\",\n",
    "        pretrained=False,\n",
    "        size=(1000, 500),\n",
    "    )\n",
    "    res_util.produce_accuracy_paper_plot(\n",
    "        parallel_figures,\n",
    "        model,\n",
    "        f\"a-{alphabet_size}_x{repetitions}_parallel\",\n",
    "        pretrained=False,\n",
    "        filter_runs = [1, 2, 4, 8, 16, 24, 32],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9916f4e-e494-4824-8f16-f547841f3f75",
   "metadata": {},
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a48a6758-7143-4160-8cdd-8b60939e44fb",
   "metadata": {},
   "source": [
    "### Training on 32 x 1024 token strings, 26 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6441815b-c0f5-422f-bb82-0384d6acd439",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "repetitions = 32\n",
    "results = load_results(\n",
    "    \"pyt-1b\",\n",
    "    26,\n",
    "    seed_ids,\n",
    "    repetitions=repetitions,\n",
    ")\n",
    "pythia_a26_seq_figures, pythia_a26_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * repetitions,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02f5857c-a9d0-4f40-b3d3-40ca76239d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pythia_a26_seq_figures,\n",
    "    pythia_a26_para_figures,\n",
    "    \"pythia-1b\",\n",
    "    26,\n",
    "    repetitions=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97cd82f3-163c-4a5a-a754-ef1320c92768",
   "metadata": {},
   "source": [
    "### Training on 32 x 1024 token strings, 2 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dc2bbc4-4815-448a-b663-d911a4b3b617",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "repetitions = 32\n",
    "results = load_results(\n",
    "    \"pyt-1b\",\n",
    "    2,\n",
    "    seed_ids,\n",
    "    repetitions=repetitions,\n",
    ")\n",
    "pythia_a2_seq_figures, pythia_a2_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * repetitions,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80dbc9ba-5290-4cb6-8ed2-7e1c1a9d53be",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pythia_a2_seq_figures,\n",
    "    pythia_a2_para_figures,\n",
    "    \"pythia-1b\",\n",
    "    2,\n",
    "    repetitions=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "337082c5-774b-419d-81ce-c6568daebde4",
   "metadata": {},
   "source": [
    "## Phi-2.7B"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ba58a53-9621-4b8d-b9a7-05315ccdfd8c",
   "metadata": {},
   "source": [
    "### Training on 16 x 1024 token strings, 26 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84a1baad-92df-4208-8bdd-b7282ec04091",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"phi-2.7b\",\n",
    "    26,\n",
    "    seed_ids,\n",
    "    repetitions=16,\n",
    ")\n",
    "phi_a26_seq_figures, phi_a26_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74e16f2-74e2-4fd3-9461-7b523a0485f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    phi_a26_seq_figures,\n",
    "    phi_a26_para_figures,\n",
    "    \"phi-2.7b\",\n",
    "    26,\n",
    "    repetitions=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09f6321c-4545-4bcc-8ee4-834b9b60c98b",
   "metadata": {},
   "source": [
    "### Training on 32 x 1024 token strings, 2 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2591f063-c111-4ea0-8ae2-10c1c25ef066",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"phi-2.7b\",\n",
    "    2,\n",
    "    seed_ids,\n",
    "    repetitions=32,\n",
    ")\n",
    "phi_a2_seq_figures, phi_a2_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 32,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc66a2b0-5cea-42e5-af4b-14cf61879f99",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    phi_a2_seq_figures,\n",
    "    phi_a2_para_figures,\n",
    "    \"phi-2.7b\",\n",
    "    2,\n",
    "    repetitions=32,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61adfb3a-b364-4f5a-9dac-34fc90ed06c0",
   "metadata": {},
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0142cb8-9da4-4a13-8844-2b73f0d5a72a",
   "metadata": {},
   "source": [
    "### Training on 16 x 1024 token strings, 26 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da7ba87c-f2b3-4033-9df7-717a2c03ab55",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"llama2-13b\",\n",
    "    26,\n",
    "    seed_ids,\n",
    "    repetitions=16,\n",
    ")\n",
    "llama_a26_seq_figures, llama_a26_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f49f4af-e319-44d5-82d5-eb1c65425e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    llama_a26_seq_figures,\n",
    "    llama_a26_para_figures,\n",
    "    \"llama2-13b\",\n",
    "    26,\n",
    "    repetitions=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c76c8ad8-5aa8-4f72-b104-fc5a17fa9c97",
   "metadata": {},
   "source": [
    "### Training on 16 x 1024 token strings, 2 character alphabet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e857fc7d-29a6-4069-bbd5-563d459634db",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"llama2-13b\",\n",
    "    2,\n",
    "    seed_ids,\n",
    "    repetitions=16,\n",
    ")\n",
    "llama_a2_seq_figures, llama_a2_para_figures = res_util.show_string_lengths(\n",
    "    results,\n",
    "    [1024] * 16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edde0127-e8a8-47ee-a54a-139e11353a39",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    llama_a2_seq_figures,\n",
    "    llama_a2_para_figures,\n",
    "    \"llama2-13b\",\n",
    "    2,\n",
    "    repetitions=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87856539-29dd-40bf-a392-9e2499cb5188",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"untrained_32x\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cf2c77f-d082-4e17-9e0f-7ace2f195854",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
