{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "# NOTE(review): presumably points the notebook at the repository root (four\n",
    "# directories up) so project imports resolve -- confirm in lib_project.notebook.\n",
    "setup_notebook(\"../../../../\")\n",
    "\n",
    "# Reload edited project modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from experiments.memorability.repeated_strings import results as res_util\n",
    "from experiments.memorization_dynamics import results as mem_res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Repeated (Sub)Strings: Fraction of Independently Sampled Tokens\n",
    "\n",
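    "We sample random strings of length $u$ and repeat them to create strings of length $1024$, so only a fraction $u/1024$ of the tokens in each string is sampled independently. The case $u = 1024$ (no repetition) corresponds to fully random strings."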
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7d29cf-d7af-4f31-a9f9-38cb675f847a",
   "metadata": {},
   "outputs": [],
   "source": [
    "SUBSTRING_LENGTHS = [16, 32, 64, 128, 256, 512]\n",
    "# Length of every training string. A substring length equal to it means no\n",
    "# repetition, i.e. every token is sampled independently.\n",
    "STRING_LENGTH = 1024\n",
    "LEGEND_TITLE = \"Substring Length\"\n",
    "VARIATION_DIMENSION = \"substring-length\"\n",
    "FIGURE_FOLDER = f\"{mem_res_util.FIGURE_FOLDER}/unique_substring_length\"\n",
    "\n",
    "\n",
    "def load_results(\n",
    "    model: tuple[str, str],\n",
    "    seed_ids: list[int],\n",
    "    alphabet_size: int,\n",
    "    substring_lengths: list[int] = SUBSTRING_LENGTHS,\n",
    ") -> dict:\n",
    "    \"\"\"Load the results for every substring length, keyed by legend label.\n",
    "\n",
    "    Args:\n",
    "        model: Pair of model identifiers: ``model[0]`` prefixes the\n",
    "            repeated-strings experiment names, ``model[1]`` is the model name\n",
    "            used by the memorization-dynamics experiments, e.g.\n",
    "            ``(\"pyt-1b\", \"pythia-1b\")``.\n",
    "        seed_ids: Seeds to load for every configuration.\n",
    "        alphabet_size: Alphabet size of the random strings.\n",
    "        substring_lengths: Substring lengths ``u`` to load from the\n",
    "            repeated-strings experiments. Only read, never modified.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping ``\"u = <u>\"`` to the loaded results, in the order of\n",
    "        ``substring_lengths``, followed by ``\"u = 1024\"``: the unrepeated\n",
    "        ``rand_...`` strings from the memorization-dynamics experiments.\n",
    "    \"\"\"\n",
    "    repeated = {\n",
    "        f\"u = {substring_length}\": res_util.load(\n",
    "            f\"{model[0]}_a-{alphabet_size}_sl-{substring_length}_ns-1_plo-iterative\",\n",
    "            seed_ids,\n",
    "        )\n",
    "        for substring_length in substring_lengths\n",
    "    }\n",
    "    # The u = STRING_LENGTH baseline (no repetition) comes from a different\n",
    "    # experiment family and results module.\n",
    "    unrepeated = mem_res_util.load(\n",
    "        [f\"rand_a-lat-{alphabet_size}_t-{STRING_LENGTH}\", model[1]],\n",
    "        seed_ids,\n",
    "    )\n",
    "    return repeated | {f\"u = {STRING_LENGTH}\": unrepeated}\n",
    "\n",
    "\n",
    "def show(results: dict):\n",
    "    \"\"\"Show the dynamics plots for ``results`` with ``show_kld=False``.\n",
    "\n",
    "    Returns the figures produced by ``mem_res_util.show_dynamics``; the\n",
    "    cells below pass them to ``produce_paper_plots``.\n",
    "    \"\"\"\n",
    "    return mem_res_util.show_dynamics(\n",
    "        results,\n",
    "        LEGEND_TITLE,\n",
    "        show_kld=False,\n",
    "    )\n",
    "\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figures: dict,\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    show_legend: bool = False,\n",
    "    show_discrepancy: bool = False,\n",
    "):\n",
    "    \"\"\"Produce the accuracy and loss paper plots for one model/alphabet pair.\n",
    "\n",
    "    Both plots are produced via ``mem_res_util`` with ``FIGURE_FOLDER`` as the\n",
    "    figure folder.\n",
    "\n",
    "    Args:\n",
    "        figures: Figures returned by ``show``.\n",
    "        model: Model name passed through to the ``mem_res_util`` plot helpers.\n",
    "        alphabet_size: Alphabet size, appended to the variation dimension as\n",
    "            ``_a-<alphabet_size>``.\n",
    "        show_legend: Whether the plots include a legend.\n",
    "        show_discrepancy: Currently unused; kept so existing calls keep working.\n",
    "    \"\"\"\n",
    "    variation_dimension = f\"{VARIATION_DIMENSION}_a-{alphabet_size}\"\n",
    "    for produce_plot in (\n",
    "        mem_res_util.produce_accuracy_paper_plot,\n",
    "        mem_res_util.produce_loss_paper_plot,\n",
    "    ):\n",
    "        produce_plot(\n",
    "            figures,\n",
    "            model,\n",
    "            variation_dimension=variation_dimension,\n",
    "            show_legend=show_legend,\n",
    "            figure_folder=FIGURE_FOLDER,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82442c7d-3061-408c-ab97-be3c419f627e",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B, alphabet size 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33e2db7e-4023-4344-b88c-ede485bba08c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    (\"pyt-1b\", \"pythia-1b\"),  # (repeated-strings prefix, model name)\n",
    "    seed_ids,\n",
    "    alphabet_size=26,\n",
    ")\n",
    "pyt_1b_a26_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ed85557-edcc-4600-ae47-b043ad78dbf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the Pythia-1B plots carry a legend.\n",
    "produce_paper_plots(\n",
    "    figures=pyt_1b_a26_figures,\n",
    "    model=\"pythia-1b\",\n",
    "    alphabet_size=26,\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2693d91-a755-4524-a2bb-38013947880d",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B, alphabet size 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be42967-a629-409c-8e35-3bccbf6ee97c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    (\"pyt-1b\", \"pythia-1b\"),  # (repeated-strings prefix, model name)\n",
    "    seed_ids,\n",
    "    alphabet_size=2,\n",
    ")\n",
    "pyt_1b_a2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a419b38-b402-4cc4-bf35-a311642ad684",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the Pythia-1B plots carry a legend.\n",
    "produce_paper_plots(\n",
    "    figures=pyt_1b_a2_figures,\n",
    "    model=\"pythia-1b\",\n",
    "    alphabet_size=2,\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65d4d932-d1d5-4299-b647-8230e0f44582",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B, alphabet size 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab46ff36-182e-4a9f-b061-386aea081143",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))  # seeds 0-4\n",
    "results = load_results(\n",
    "    model=(\"llama2-13b\", \"Llama-2-13b-hf\"), seed_ids=seed_ids, alphabet_size=26\n",
    ")\n",
    "llama2_13b_a26_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a6d453b-8807-4b41-8201-b9437be9803f",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    figures=llama2_13b_a26_figures, model=\"llama2-13b\", alphabet_size=26\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "134e7a9e-1983-41db-96e9-2dcd6d2ed640",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B, alphabet size 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b67dd320-fa57-421d-97d8-4d30294480d0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))  # seeds 0-4\n",
    "results = load_results(\n",
    "    model=(\"llama2-13b\", \"Llama-2-13b-hf\"), seed_ids=seed_ids, alphabet_size=2\n",
    ")\n",
    "llama2_13b_a2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "282b87b0-320a-4c09-b5ed-58f8f357f548",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    figures=llama2_13b_a2_figures, model=\"llama2-13b\", alphabet_size=2\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77ca535d-fdee-4212-8760-2ff6ed782d5a",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters), alphabet size 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19188389-aaa1-4f4b-b293-c128c5ed6abd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))  # seeds 0-4\n",
    "results = load_results(\n",
    "    model=(\"phi-2.7b\", \"phi-2\"), seed_ids=seed_ids, alphabet_size=26\n",
    ")\n",
    "phi_2_a26_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "234422df-970c-46cc-bd9a-7e0d9eb40647",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    figures=phi_2_a26_figures, model=\"phi-2.7b\", alphabet_size=26\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de9025d1-137e-46ea-b848-d19e89f48972",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters), alphabet size 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d6e7b11-afe5-4c5a-a218-16a3b682c6c4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))  # seeds 0-4\n",
    "results = load_results(\n",
    "    model=(\"phi-2.7b\", \"phi-2\"), seed_ids=seed_ids, alphabet_size=2\n",
    ")\n",
    "phi_2_a2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6014aff-c55f-4a9a-bbf6-fa6d3ddb4cd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    figures=phi_2_a2_figures, model=\"phi-2.7b\", alphabet_size=2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7850eb6-6bc3-411d-bffc-227e48dbe1f5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload via res_util.publish under the name \"frac_independent_tokens\".\n",
    "# NOTE(review): what exactly gets uploaded, and where, is defined in\n",
    "# res_util.publish -- confirm there.\n",
    "res_util.publish(\"frac_independent_tokens\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de626a7a-9492-4fe0-ad94-b525e91d1c84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
