{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Notebook bootstrap: call the project's setup helper with the relative path \"../../../\".\n",
    "# NOTE(review): presumably this path points at the repository root so that the project\n",
    "# imports in the next cell resolve -- confirm against lib_project.notebook.setup_notebook.\n",
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "# autoreload mode 2: reload all modules before executing each cell, so edits to the\n",
    "# project's .py files are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import display, Markdown as md\n",
    "\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.memorization_dynamics import results as mem_res_util\n",
    "from experiments.prefix_mappings import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b9aeeb-86f0-4aa0-93cb-24b6300c7407",
   "metadata": {},
   "source": [
    "# How do prefixes of different lengths perform at different stages of training?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a0c87c-70a0-4ec1-82b6-c472cecdf857",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alphabet sizes iterated by load_results (default value of its `alphabet_sizes` argument).\n",
    "ALPHABET_SIZES = [2, 4, 7, 13, 26]\n",
    "# Title passed to res_util.show_memorization_dynamics by show().\n",
    "LEGEND_TITLE = \"Alphabet Size\"\n",
    "# Leading part of the `variation_dimension` string built in produce_paper_plots.\n",
    "# NOTE(review): the identifier looks garbled (probably meant VARIATION_DIMENSION); it is\n",
    "# left unchanged because produce_paper_plots refers to it by this exact name.\n",
    "VARIATION_DIMEANONYMOUSION = \"alphabet-size\"\n",
    "\n",
    "def load_results(\n",
    "    model: tuple[str, str],\n",
    "    seed_ids: list[int],\n",
    "    alphabet_sizes: list[int] = ALPHABET_SIZES,\n",
    ") -> tuple[dict, dict]:\n",
    "    \"\"\"Load memorization and prefix-mapping results for each alphabet size.\n",
    "\n",
    "    Args:\n",
    "        model: ``(short_name, long_name)`` pair. The long name is used in the\n",
    "            memorization lookup, the short name in the prefix-mapping lookup.\n",
    "        seed_ids: Seed ids forwarded unchanged to both loaders.\n",
    "        alphabet_sizes: Alphabet sizes to load; defaults to ``ALPHABET_SIZES``.\n",
    "\n",
    "    Returns:\n",
    "        ``(mem_results, prefix_mapping_results)``: two dicts keyed by\n",
    "        ``\"l = <alphabet_size>\"``. A key is present only when the\n",
    "        corresponding loader returned something of non-zero length.\n",
    "    \"\"\"\n",
    "    short_name, long_name = model\n",
    "\n",
    "    def _store_if_any(target: dict, key: str, loaded) -> None:\n",
    "        # Configurations for which nothing was loaded are left out entirely.\n",
    "        if len(loaded) > 0:\n",
    "            target[key] = loaded\n",
    "\n",
    "    memorization = {}\n",
    "    prefix_mappings = {}\n",
    "    for size in alphabet_sizes:\n",
    "        label = f\"l = {size}\"\n",
    "        mem_run = [f\"rand_a-lat-{size}_t-1024\", long_name]\n",
    "        _store_if_any(memorization, label, mem_res_util.load(mem_run, seed_ids))\n",
    "        pm_run = [\"intermediate_eval\", f\"{short_name}_a-{size}_t-1024\"]\n",
    "        _store_if_any(prefix_mappings, label, res_util.load(pm_run, seed_ids))\n",
    "    return memorization, prefix_mappings\n",
    "\n",
    "def show(results: tuple[dict, dict]):\n",
    "    \"\"\"Show the memorization dynamics and the per-epoch prefix lengths.\n",
    "\n",
    "    Args:\n",
    "        results: The ``(mem_results, pm_results)`` pair, e.g. as returned by\n",
    "            ``load_results``.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``res_util.show_epoch_prefix_lengths`` returns for the\n",
    "        prefix-mapping results (this notebook later iterates it with\n",
    "        ``.items()``).\n",
    "    \"\"\"\n",
    "    memorization, prefix_mappings = results\n",
    "    # The return value of this call is intentionally not used.\n",
    "    res_util.show_memorization_dynamics(memorization, LEGEND_TITLE)\n",
    "    return res_util.show_epoch_prefix_lengths(prefix_mappings)\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figures: dict,\n",
    "    model: str,\n",
    "):\n",
    "    \"\"\"Produce one paper plot per configuration figure.\n",
    "\n",
    "    Args:\n",
    "        figures: Mapping from configuration name (``\"l = <alphabet_size>\"``,\n",
    "            the keys built by ``load_results``) to that configuration's figure.\n",
    "        model: Model label forwarded to\n",
    "            ``res_util.produce_epoch_prefix_paper_plot``.\n",
    "\n",
    "    The legend is requested only for the last figure in iteration order.\n",
    "    \"\"\"\n",
    "    name_prefix = \"l = \"\n",
    "    # Tie the legend to the last figure instead of a fixed index, so it is\n",
    "    # still drawn when some configurations are missing from `figures`.\n",
    "    last_index = len(figures) - 1\n",
    "    for i, (config_name, config_fig) in enumerate(figures.items()):\n",
    "        # Strip the label prefix to recover the alphabet size; names without\n",
    "        # the prefix are passed through unchanged rather than truncated.\n",
    "        alphabet_size = config_name.removeprefix(name_prefix)\n",
    "        res_util.produce_epoch_prefix_paper_plot(\n",
    "            config_fig,\n",
    "            model,\n",
    "            variation_dimension=f\"{VARIATION_DIMEANONYMOUSION}_a-{alphabet_size}\",\n",
    "            show_legend=i == last_index,\n",
    "            figure_folder=\"prefix_mappings/prefix_lengths/pretrained\",\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45c58ed0-30cb-4902-bd64-ccf912cb0309",
   "metadata": {},
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe49afea-c0f5-4129-81ac-8698a506fed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pythia-1B: load seeds 0-9 for every alphabet size, then show the results.\n",
    "seed_ids = list(range(10))\n",
    "results = load_results(\n",
    "    (\"pyt-1b\", \"pythia-1b\"),\n",
    "    seed_ids,\n",
    ")\n",
    "pyt_1b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8bce04a-a18f-4fd8-9b8d-2510bea2bd9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the results were loaded with the short name \"pyt-1b\" but the plot label\n",
    "# here is \"pythia-1b\"; the other models reuse their short name -- confirm this is intended.\n",
    "produce_paper_plots(pyt_1b_figures, \"pythia-1b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062504d2-163a-4bd0-839c-7888760a418b",
   "metadata": {},
   "source": [
    "## Phi-2.7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de2be29e-a51d-416e-8a23-77ec9f1784ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phi-2.7B: load seeds 0-9 for every alphabet size, then show the results.\n",
    "seed_ids = list(range(10))\n",
    "results = load_results(\n",
    "    (\"phi-2.7b\", \"phi-2\"),\n",
    "    seed_ids,\n",
    ")\n",
    "phi_2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffd892a8-5e8a-4d40-9fd1-19fa36ad9ae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the Phi-2.7B figures under the model label \"phi-2.7b\".\n",
    "produce_paper_plots(phi_2_figures, \"phi-2.7b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7662c84-6a16-4fe0-ae69-72ad000980d1",
   "metadata": {},
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58aca706-3901-445b-b137-cd1a15b57d4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Llama2-13B: load seeds 0-4 for every alphabet size, then show the results.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    (\"llama2-13b\", \"Llama-2-13b-hf\"),\n",
    "    seed_ids,\n",
    ")\n",
    "llama_13b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3184fa74-fc02-48c3-925d-110a7696eb4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the Llama2-13B figures under the model label \"llama2-13b\".\n",
    "produce_paper_plots(llama_13b_figures, \"llama2-13b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef5d4eb5-a0ab-4cdf-b23f-bab4ae58464d",
   "metadata": {},
   "source": [
    "## GPT2-124M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde2fa3e-bb5b-4539-8a7c-c0197e280926",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPT2-124M: load seeds 0-2 for every alphabet size, then show the results.\n",
    "seed_ids = list(range(3))\n",
    "results = load_results((\"gpt2-124m\", \"gpt2\"), seed_ids)\n",
    "gpt2_124m_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84e871be-735f-4557-bce5-8ffefa4bd5c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the GPT2-124M figures under the model label \"gpt2-124m\".\n",
    "produce_paper_plots(gpt2_124m_figures, \"gpt2-124m\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4df3dc3b-1163-4e51-9997-34aa167c1b5a",
   "metadata": {},
   "source": [
    "## OPT-350M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee78aba1-0273-4b3b-9429-e807fd9d9477",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OPT-350M: load seeds 0-2 for every alphabet size, then show the results.\n",
    "seed_ids = list(range(3))\n",
    "results = load_results((\"opt-350m\", \"opt-350m\"), seed_ids)\n",
    "opt_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33c94922-0864-4258-bb90-01ab1624c934",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the OPT-350M figures under the model label \"opt-350m\".\n",
    "produce_paper_plots(opt_figures, \"opt-350m\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270f5680-9e91-4555-8b67-2124bb37d472",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload: call res_util.publish with the name \"prefix_length/alphabet_size\".\n",
    "# NOTE(review): what is uploaded, and where, is defined in\n",
    "# experiments.prefix_mappings.results -- presumably the plots produced above; confirm.\n",
    "res_util.publish(\"prefix_length/alphabet_size\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55279102-19c2-41eb-9523-d39be1aa0a3d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
