{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "\n",
    "# NOTE(review): presumably points the notebook at the project root three\n",
    "# directories up so that project modules can be imported - confirm in\n",
    "# lib_project.notebook.\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "# autoreload mode 2: reload modules before executing each cell, so edits to\n",
    "# imported modules are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import display, Markdown as md\n",
    "\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.prefix_mappings import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b9aeeb-86f0-4aa0-93cb-24b6300c7407",
   "metadata": {},
   "source": [
    "# How does the replacement strategy for the non-prefix tokens affect performance?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a0c87c-70a0-4ec1-82b6-c472cecdf857",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide settings shared by every per-model section below.\n",
    "LEGEND_TITLE = \"Replacement Strategy\"\n",
    "VARIATION_DIMENSION = \"replacement-strategy\"\n",
    "FIGURE_FOLDER = f\"{res_util.FIGURE_FOLDER}/replacement_strategy\"\n",
    "\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    seed_ids: list[int],\n",
    ") -> dict:\n",
    "    \"\"\"Load the runs compared in this notebook, keyed by legend label.\n",
    "\n",
    "    Args:\n",
    "        model: Model tag that prefixes the experiment config names,\n",
    "            e.g. \"pyt-1b\".\n",
    "        seed_ids: Seed ids passed unchanged to every ``res_util.load`` call.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping a strategy label (\"Random\", \"Constant\") to the value\n",
    "        returned by ``res_util.load`` for that strategy's experiment.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        # The \"Random\" entry is read from the intermediate_eval experiment.\n",
    "        \"Random\": res_util.load(\n",
    "            [\"intermediate_eval\", f\"{model}_a-26_t-1024\"],\n",
    "            seed_ids,\n",
    "        ),\n",
    "        \"Constant\": res_util.load(\n",
    "            [\"replacement_strategy\", f\"{model}_t-1024_rs-const_id\"],\n",
    "            seed_ids,\n",
    "        ),\n",
    "        # Disabled variant, kept for reference (size_change experiment, sc-0):\n",
    "        # \"Elimination\": res_util.load(\n",
    "        #     [\"size_change\", f\"{model}_t-1024_sc-0\"],\n",
    "        #     seed_ids,\n",
    "        # ),\n",
    "    }\n",
    "\n",
    "\n",
    "def show(results: dict):\n",
    "    \"\"\"Display the prefix-length performance comparison for ``results``.\n",
    "\n",
    "    Delegates to ``res_util.show_prefix_length_performance`` using this\n",
    "    notebook's legend title and returns whatever that helper returns; the\n",
    "    per-model cells pass the returned value on to ``produce_paper_plots``.\n",
    "    \"\"\"\n",
    "    return res_util.show_prefix_length_performance(\n",
    "        results,\n",
    "        LEGEND_TITLE,\n",
    "    )\n",
    "\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figure,\n",
    "    model: str,\n",
    ") -> None:\n",
    "    \"\"\"Forward ``figure`` to ``res_util.produce_prefix_length_paper_plot``.\n",
    "\n",
    "    Args:\n",
    "        figure: Value returned by ``show`` for the same model.\n",
    "        model: Model name handed to the plotting helper.\n",
    "\n",
    "    The variation dimension and figure folder are fixed by the constants\n",
    "    defined at the top of this cell. Nothing is returned.\n",
    "    \"\"\"\n",
    "    res_util.produce_prefix_length_paper_plot(\n",
    "        figure,\n",
    "        model,\n",
    "        variation_dimension=VARIATION_DIMENSION,\n",
    "        figure_folder=FIGURE_FOLDER,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45c58ed0-30cb-4902-bd64-ccf912cb0309",
   "metadata": {},
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd081ee-c487-49d1-979d-8d865c81d058",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load seeds 0-4 for the \"pyt-1b\" runs and display the comparison.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(model=\"pyt-1b\", seed_ids=seed_ids)\n",
    "pyt_1b_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8bce04a-a18f-4fd8-9b8d-2510bea2bd9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the figure from the previous cell to the paper-plot helper.\n",
    "produce_paper_plots(figure=pyt_1b_fig, model=\"pythia-1b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c039c9c-ea63-417e-952f-bd7507562b13",
   "metadata": {},
   "source": [
    "## Phi-2.7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c80e440e-105f-42f4-9c1e-94b3c7b5621b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load seeds 0-4 for the \"phi-2.7b\" runs and display the comparison.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(model=\"phi-2.7b\", seed_ids=seed_ids)\n",
    "phi_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a361f80-f176-44a0-83b4-09f90db93bea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the figure from the previous cell to the paper-plot helper.\n",
    "produce_paper_plots(figure=phi_fig, model=\"phi-2.7b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26ae87de-c5ab-4c46-8a31-0ee02dbb0586",
   "metadata": {},
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dfdcd95-abae-40a2-9a29-9027b692efdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load seeds 0-4 for the \"llama2-13b\" runs and display the comparison.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(model=\"llama2-13b\", seed_ids=seed_ids)\n",
    "llama_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e42b0fa0-3cb6-446d-be49-7bbe2506c452",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the figure from the previous cell to the paper-plot helper.\n",
    "produce_paper_plots(figure=llama_fig, model=\"llama2-13b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef5d4eb5-a0ab-4cdf-b23f-bab4ae58464d",
   "metadata": {},
   "source": [
    "## GPT2-124M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde2fa3e-bb5b-4539-8a7c-c0197e280926",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load seeds 0-4 for the \"gpt2-124m\" runs and display the comparison.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(model=\"gpt2-124m\", seed_ids=seed_ids)\n",
    "gpt2_124m_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "615415f8-578e-4fc6-b316-fa0149c3e92c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the figure from the previous cell to the paper-plot helper.\n",
    "produce_paper_plots(figure=gpt2_124m_fig, model=\"gpt2-124m\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e4ae158-fbff-45e5-9bf0-091b562ea775",
   "metadata": {},
   "source": [
    "## OPT-350M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32060f0-cc98-4aec-a6a7-b927a3d9dbab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load seeds 0-4 for the \"opt-350m\" runs and display the comparison.\n",
    "seed_ids = list(range(5))\n",
    "results = load_results(model=\"opt-350m\", seed_ids=seed_ids)\n",
    "opt_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c194aaf5-a6c6-464c-aa07-6e9de640aeb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the figure from the previous cell to the paper-plot helper.\n",
    "produce_paper_plots(figure=opt_fig, model=\"opt-350m\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload step: calls res_util.publish with this notebook's target path.\n",
    "res_util.publish(\"prefix_length/replacement_strategy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55279102-19c2-41eb-9523-d39be1aa0a3d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
