{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import display, Markdown as md\n",
    "\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.prefix_mappings import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b9aeeb-86f0-4aa0-93cb-24b6300c7407",
   "metadata": {},
   "source": [
    "# How is prefix performance affected by adding or removing tokens?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a0c87c-70a0-4ec1-82b6-c472cecdf857",
   "metadata": {},
   "outputs": [],
   "source": [
    "LEGEND_TITLE = \"Size Change\"\n",
    "VARIATION_DIMEANONYMOUSION = \"size-change\"\n",
    "FIGURE_FOLDER = f\"{res_util.FIGURE_FOLDER}/size_change\"\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    seed_ids: list[int],\n",
    "    string_length: int = 1024,\n",
    ") -> dict:\n",
    "    return {\n",
    "        f\"{size_change:.1f} x GC\": res_util.load(\n",
    "            [\n",
    "                (\n",
    "                    \"size_change\"\n",
    "                    if size_change != 1\n",
    "                    else \"intermediate_eval\"\n",
    "                ),\n",
    "                (\n",
    "                    f\"{model}_t-{string_length}_sc-{size_change}\"\n",
    "                    if size_change != 1\n",
    "                    else f\"{model}_a-26_t-1024\"\n",
    "                )\n",
    "            ],\n",
    "            seed_ids,\n",
    "        )\n",
    "        for size_change in [0, 0.5, 1, 1.5, 2]\n",
    "    }\n",
    "\n",
    "def show(results: dict):\n",
    "    return res_util.show_prefix_length_performance(\n",
    "        results,\n",
    "        LEGEND_TITLE,\n",
    "    )\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figure,\n",
    "    model: str,\n",
    "):\n",
    "     res_util.produce_prefix_length_paper_plot(\n",
    "        figure,\n",
    "        model,\n",
    "        variation_dimension=VARIATION_DIMEANONYMOUSION,\n",
    "        figure_folder=FIGURE_FOLDER,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45c58ed0-30cb-4902-bd64-ccf912cb0309",
   "metadata": {},
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd081ee-c487-49d1-979d-8d865c81d058",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = results = load_results(\n",
    "    \"pyt-1b\",\n",
    "    seed_ids,\n",
    ")   \n",
    "pyt_1b_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8bce04a-a18f-4fd8-9b8d-2510bea2bd9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pyt_1b_fig,\n",
    "    \"pythia-1b\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dda5910-cccb-4a59-970d-c878fcccc70b",
   "metadata": {},
   "source": [
    "## Phi-2.7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "221d0db6-44a2-4183-9011-f5971d06ef43",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = results = load_results(\n",
    "    \"phi-2.7b\",\n",
    "    seed_ids,\n",
    ")   \n",
    "phi_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8feb00a0-3524-4e1f-ab96-61b29304567e",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    phi_fig,\n",
    "    \"phi-2.7b\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb899edc-1b82-4854-9a1a-10dbb7efe232",
   "metadata": {},
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59a4872-d2c4-4482-93d9-2a2057e175cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = results = load_results(\n",
    "    \"llama2-13b\",\n",
    "    seed_ids,\n",
    ")   \n",
    "llama_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e750a0d-24c6-477d-82cf-61fb2ea5f6da",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    llama_fig,\n",
    "    \"llama2-13b\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef5d4eb5-a0ab-4cdf-b23f-bab4ae58464d",
   "metadata": {},
   "source": [
    "## GPT2-124M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde2fa3e-bb5b-4539-8a7c-c0197e280926",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"gpt2-124m\",\n",
    "    seed_ids,\n",
    "    string_length=512,\n",
    ")\n",
    "gpt2_124m_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96a31f14-2cd9-4180-97a8-9a3b556dc687",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    gpt2_124m_fig,\n",
    "    \"gpt2-124m\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d195175-66d2-499a-9cf6-07524ee469dc",
   "metadata": {},
   "source": [
    "## OPT-350M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f4f519f-f68f-4f02-9075-a8f2308c65a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"opt-350m\",\n",
    "    seed_ids,\n",
    ")\n",
    "opt_fig = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3743aa61-65c4-417d-9d18-615b5deacff4",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    opt_fig,\n",
    "    \"opt-350m\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c3751b-50c7-4a56-a3de-052791e919da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"prefix_length/size_change\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55279102-19c2-41eb-9523-d39be1aa0a3d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
