{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "# Reload edited project modules (e.g. res_util) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from lib_project.visualization import with_paper_style\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.memorization_dynamics import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Memorization Dynamics: Different Entropy Levels\n",
    "\n",
    "We train models for 100 epochs on random strings of 1024 tokens with different entropy levels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b248c03d-9ed6-4d20-9586-cc2372f9468b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ENTROPY_LEVELS = [2, 4, 7, 13, 26]\n",
    "LEGEND_TITLE = \"Entropy Level\"\n",
    "VARIATION_DIMENSION = \"entropy-level\"\n",
    "FIGURE_FOLDER = f\"{res_util.FIGURE_FOLDER}/entropy_level\"\n",
    "# Alphabet size of the random-string datasets (\"rand_a-lat-26\"). When an\n",
    "# entropy level equals it, the dataset id has no \"_h-\" component.\n",
    "ALPHABET_SIZE = 26\n",
    "# Token count encoded in every dataset id (\"_t-1024\").\n",
    "NUM_TOKENS = 1024\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    seed_ids: list[int],\n",
    "    entropy_levels: list[int] = ENTROPY_LEVELS,\n",
    "    num_tokens: int = NUM_TOKENS,\n",
    ") -> dict:\n",
    "    \"\"\"Load the results of `model` for each requested entropy level.\n",
    "\n",
    "    Args:\n",
    "        model: Model name forwarded to `res_util.load` (e.g. \"pythia-70m\").\n",
    "        seed_ids: Seed ids forwarded to `res_util.load`.\n",
    "        entropy_levels: Entropy levels to load. The caller's order is kept,\n",
    "            except that the level equal to ALPHABET_SIZE is always loaded last.\n",
    "        num_tokens: Token count used in the dataset id (\"_t-<num_tokens>\").\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping the label \"h = H<sub>level</sub>\" (HTML subscript markup)\n",
    "        to the value returned by `res_util.load` for that level.\n",
    "    \"\"\"\n",
    "    # Stable sort: the ALPHABET_SIZE level moves to the end, all others keep their order.\n",
    "    ordered_levels = sorted(entropy_levels, key=lambda level: level == ALPHABET_SIZE)\n",
    "    results = {}\n",
    "    for entropy_level in ordered_levels:\n",
    "        if entropy_level == ALPHABET_SIZE:\n",
    "            # The full-alphabet dataset id carries no \"_h-\" component.\n",
    "            dataset = f\"rand_a-lat-{ALPHABET_SIZE}_t-{num_tokens}\"\n",
    "        else:\n",
    "            dataset = f\"rand_a-lat-{ALPHABET_SIZE}_h-{entropy_level}_t-{num_tokens}\"\n",
    "        results[f\"h = H<sub>{entropy_level}</sub>\"] = res_util.load(\n",
    "            [dataset, model],\n",
    "            seed_ids,\n",
    "        )\n",
    "    return results\n",
    "\n",
    "def show(results: dict):\n",
    "    \"\"\"Plot `results` with `res_util.show_dynamics` using LEGEND_TITLE.\n",
    "\n",
    "    Returns whatever `res_util.show_dynamics` returns; the cells below pass\n",
    "    it to `produce_paper_plots` as `figures`.\n",
    "    \"\"\"\n",
    "    return res_util.show_dynamics(results, LEGEND_TITLE)\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figures: dict,\n",
    "    model: str,\n",
    "    show_legend: bool = False,\n",
    "    show_discrepancy: bool = False,\n",
    "):\n",
    "    \"\"\"Produce the accuracy, loss, cumulative-probability, entropy and KLD paper plots.\n",
    "\n",
    "    Each plot is produced by the matching `res_util.produce_*_paper_plot`\n",
    "    function, called with VARIATION_DIMENSION and FIGURE_FOLDER.\n",
    "\n",
    "    Args:\n",
    "        figures: The value returned by `show` for `model`.\n",
    "        model: Model name forwarded to every plot function.\n",
    "        show_legend: Forwarded to every plot function.\n",
    "        show_discrepancy: Not used; accepted so existing calls keep working.\n",
    "    \"\"\"\n",
    "    plot_producers = (\n",
    "        res_util.produce_accuracy_paper_plot,\n",
    "        res_util.produce_loss_paper_plot,\n",
    "        res_util.produce_cum_prob_paper_plot,\n",
    "        res_util.produce_entropy_paper_plot,\n",
    "        res_util.produce_kld_paper_plot,\n",
    "    )\n",
    "    for produce_plot in plot_producers:\n",
    "        produce_plot(\n",
    "            figures,\n",
    "            model,\n",
    "            variation_dimension=VARIATION_DIMENSION,\n",
    "            figure_folder=FIGURE_FOLDER,\n",
    "            show_legend=show_legend,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "864d6fb0-95a8-45b3-8663-834ccd6951ad",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-70M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b45623e-49d9-4bc0-896b-1583ebfe200d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load every entropy level for seeds 0-9, then show the dynamics.\n",
    "seed_ids = [*range(10)]\n",
    "results = load_results(model=\"pythia-70m\", seed_ids=seed_ids)\n",
    "pythia_70m_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca85afe3-322d-4008-a8b9-ce0fb0e6d5ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=pythia_70m_figures, model=\"pythia-70m\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3942594f-5caf-4e14-92ac-9c3be7cd91f9",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c24a7c11-b351-4071-9e63-c3911bfe42e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(10))\n",
    "results = load_results(\n",
    "    \"pythia-1b\",\n",
    "    seed_ids,\n",
    ")\n",
    "pythia_1b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "099bd817-5a69-4fb0-8bf2-cc27aa254965",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pythia_1b_figures,\n",
    "    \"pythia-1b\",\n",
    "    # In this notebook only the Pythia-1B plots are produced with a legend.\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db3c36b1-7716-4fab-abc4-5d9e4499e0a4",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-12B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75dcd809-af08-4a6c-ab28-d7a974981c17",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load every entropy level for seeds 0-4, then show the dynamics.\n",
    "seed_ids = [*range(5)]\n",
    "results = load_results(model=\"pythia-12b\", seed_ids=seed_ids)\n",
    "pythia_12b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce2e20d7-da72-4c30-b9df-f6dd2553fe4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=pythia_12b_figures, model=\"pythia-12b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9bd2236-1095-4105-82a8-2857907bd657",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87884481-7633-46d7-a5e4-57cfb7a4a5fd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load every entropy level for seeds 0-9, then show the dynamics.\n",
    "seed_ids = [*range(10)]\n",
    "results = load_results(model=\"Llama-2-7b-hf\", seed_ids=seed_ids)\n",
    "llama2_7b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90c2b1f2-a492-4288-abc7-935c60be37b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=llama2_7b_figures, model=\"llama2-7b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5789a54f-fe86-488e-9f3c-d4932ed80c47",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8673f9-8c18-473d-80dd-d92398790c79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load every entropy level for seeds 0-4, then show the dynamics.\n",
    "seed_ids = [*range(5)]\n",
    "results = load_results(model=\"Llama-2-13b-hf\", seed_ids=seed_ids)\n",
    "llama2_13b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb0a6634-d135-4450-8939-983174be728d",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=llama2_13b_figures, model=\"llama2-13b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "870e602c-bb12-473b-9c9b-c26d7d3c7f5b",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-1.5 (1.3B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d129e3f-581c-42e7-a1ad-59315f47bde1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load every entropy level for seeds 0-9, then show the dynamics.\n",
    "seed_ids = [*range(10)]\n",
    "results = load_results(model=\"phi-1_5\", seed_ids=seed_ids)\n",
    "phi_1_5_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eca17774-f669-41a1-8d40-1341782f83ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=phi_1_5_figures, model=\"phi-1.3b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70a0def6-2fda-4718-8342-abb765ac38dd",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2.7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71909f67-4fd8-4e8b-bc98-a22eb93bf918",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load every entropy level for seeds 0-9, then show the dynamics.\n",
    "seed_ids = [*range(10)]\n",
    "results = load_results(model=\"phi-2\", seed_ids=seed_ids)\n",
    "phi_2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49de284f-bc73-4684-a528-b5b875da42c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=phi_2_figures, model=\"phi-2.7b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e875722a-1508-491a-ae2d-a531e81070ef",
   "metadata": {
    "tags": []
   },
   "source": [
    "## GPT2-124M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "907e537f-3938-447f-84f4-72e87c56fa09",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"gpt2\",\n",
    "    seed_ids,\n",
    ")\n",
    "gpt2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb0a805-877f-4bee-9875-3ac7e5a24922",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=gpt2_figures, model=\"gpt2-124m\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc7e3779-c534-406a-9d74-27dd8de7ccb2",
   "metadata": {
    "tags": []
   },
   "source": [
    "## GPT2-1.5B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e2005a6-5a13-45a6-beb5-efcc80b58846",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"gpt2-xl\",\n",
    "    seed_ids,\n",
    ")\n",
    "gpt2_1_5b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdc65b33-39d0-4cf9-a7ef-36c99161eded",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(figures=gpt2_1_5b_figures, model=\"gpt2-1.5b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4088a51f-161a-491f-875d-2e397820b1b6",
   "metadata": {
    "tags": []
   },
   "source": [
    "## OPT-350M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b334ded-5dbc-4e29-b507-6160100bebec",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    \"opt-350m\",\n",
    "    seed_ids,\n",
    ")\n",
    "opt_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d80c7718-5710-48f7-b711-2fbe656edfcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    opt_figures,\n",
    "    \"opt-350m\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"dynamics_analysis/entropy_levels\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
