{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import display, Markdown as md\n",
    "\n",
    "from experiments.memorization_dynamics import results as mem_res_util\n",
    "from experiments.prefix_mappings import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b9aeeb-86f0-4aa0-93cb-24b6300c7407",
   "metadata": {},
   "source": [
    "# How do prefixes of different lengths perform at different stages of training? - Entropy Levels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a0c87c-70a0-4ec1-82b6-c472cecdf857",
   "metadata": {},
   "outputs": [],
   "source": [
    "ENTROPY_LEVELS = [2, 4, 7, 13]\n",
    "LEGEND_TITLE = \"Entropy Level\"\n",
    "VARIATION_DIMEANONYMOUSION = \"entropy_level\"\n",
    "\n",
    "def load_results(\n",
    "    model: tuple[str, str],\n",
    "    seed_ids: list[int],\n",
    "    entropy_levels: list[int] = ENTROPY_LEVELS,\n",
    ") -> tuple[dict, dict]:\n",
    "    model_short, model_long = model\n",
    "    \n",
    "    mem_results = {}\n",
    "    prefix_mapping_results = {}\n",
    "    for entropy_level in entropy_levels:\n",
    "        res_name = f\"h = {entropy_level}\"\n",
    "        mem_results[res_name] = mem_res_util.load(\n",
    "            [f\"rand_a-lat-26_h-{entropy_level}_t-1024\", model_long],\n",
    "            seed_ids,\n",
    "        )\n",
    "        prefix_mapping_results[res_name] = res_util.load(\n",
    "            [\"intermediate_eval\", f\"{model_short}_h-{entropy_level}_t-1024\"],\n",
    "            seed_ids,\n",
    "        )\n",
    "\n",
    "    res_name = f\"h = 26\"\n",
    "    mem_results[res_name] = mem_res_util.load(\n",
    "        [f\"rand_a-lat-26_t-1024\", model_long],\n",
    "        seed_ids,\n",
    "    )\n",
    "    prefix_mapping_results[res_name] = res_util.load(\n",
    "        [\"intermediate_eval\", f\"{model_short}_a-26_t-1024\"],\n",
    "        seed_ids,\n",
    "    )\n",
    "    \n",
    "    return mem_results, prefix_mapping_results\n",
    "\n",
    "def show(results: tuple[dict, dict]):\n",
    "    mem_results, pm_results = results\n",
    "    res_util.show_memorization_dynamics(\n",
    "        mem_results,\n",
    "        LEGEND_TITLE,\n",
    "    )\n",
    "    figs = res_util.show_epoch_prefix_lengths(\n",
    "         pm_results,\n",
    "    )\n",
    "    return figs\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figures: dict,\n",
    "    model: str,\n",
    "):\n",
    "    for i, (config_name, config_fig) in enumerate(figures.items()):\n",
    "        res_util.produce_epoch_prefix_paper_plot(\n",
    "            config_fig,\n",
    "            model,\n",
    "            variation_dimension=f\"{VARIATION_DIMEANONYMOUSION}_a-{config_name[4:]}\",\n",
    "            show_legend=i == 4,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45c58ed0-30cb-4902-bd64-ccf912cb0309",
   "metadata": {},
   "source": [
    "## Pythia-1B\n",
    "\n",
    "TODO: need to rerun because of datatype issues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe49afea-c0f5-4129-81ac-8698a506fed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = [1, 2]\n",
    "results = results = load_results(\n",
    "    (\"pyt-1b\", \"pythia-1b\"),\n",
    "    seed_ids,\n",
    ")   \n",
    "pyt_1b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8bce04a-a18f-4fd8-9b8d-2510bea2bd9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pyt_1b_figures,\n",
    "    \"pythia-1b\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef5d4eb5-a0ab-4cdf-b23f-bab4ae58464d",
   "metadata": {},
   "source": [
    "## GPT2-124M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde2fa3e-bb5b-4539-8a7c-c0197e280926",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = [1, 2]\n",
    "results = load_results(\n",
    "    (\"gpt2-124m\", \"gpt2\"),\n",
    "    seed_ids,\n",
    ")\n",
    "gpt2_124m_figures = show(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7662c84-6a16-4fe0-ae69-72ad000980d1",
   "metadata": {},
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58aca706-3901-445b-b137-cd1a15b57d4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = results = load_results(\n",
    "    (\"llama2-13bm\", \"Llama-2-13b-hf\"),\n",
    "    seed_ids,\n",
    ")   \n",
    "pyt_1b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3623d6f0-aa6c-4ebe-93b3-1c97218229e0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"prefix_length/entropy_level\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55279102-19c2-41eb-9523-d39be1aa0a3d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
