{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import display, Markdown as md\n",
    "\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.memorization_dynamics import results as mem_res_util\n",
    "from experiments.prefix_mappings import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b9aeeb-86f0-4aa0-93cb-24b6300c7407",
   "metadata": {},
   "source": [
    "# How do prefix lengths behave for conditional probability distributions?\n",
    "\n",
    "We want to test how strings with conditional n-gram distributions, i.e. where the probability of the next token depends on the preceding n-gram, influence the prefix length of models.\n",
    "\n",
    "By the n-conditional entropy $H_n(s)$ of a string $s$, we refer to the entropy $H_n(s) = H(s_i | s_{i-n}, \\dots, s_{i-1})$, i.e. the entropy over tokens in $s$ given that the preceding n tokens (the preceding n-gram) are known.\n",
    "We are interested in whether, at the same level of unconditional entropy $H(s)$, i.e. 0-conditional entropy, strings with different levels of n-conditional entropy $H_n(s)$ differ in their memorability.\n",
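    "\n",
    "Written out over n-grams, and assuming entropy is measured in bits (the base-2 logarithm is an assumption; the text above does not fix the base), this conditional entropy is\n",
    "\n",
    "$$H_n(s) = -\\sum_{g \\in A^n} p(g) \\sum_{t \\in A} p(t \\mid g) \\log_2 p(t \\mid g),$$\n",
    "\n",
    "where $p(g)$ is the probability of the n-gram $g$ and $p(t \\mid g)$ is the probability that token $t$ follows $g$. For $n = 0$ the outer sum disappears and the expression reduces to the unconditional entropy $H(s) = -\\sum_{t \\in A} p(t) \\log_2 p(t)$.\n",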
    "\n",
    "## Methodology\n",
    "\n",
    "**Privileged continuation tokens**:\n",
    "We create a string $s$ over an alphabet $A$ with a certain level of n-conditional entropy by assigning each possible n-gram $g$ over $A$ a *privileged continuation token* $t_g$.\n",
    "E.g. for $A = \\{a, b\\}$, there are the 2-grams $aa, ab, ba, bb$, and each of them has a privileged continuation, e.g. $b$ for $aa$, $a$ for $ab$, etc.\n",
    "\n",
    "**Constructing strings with different levels of conditional entropy**:\n",
    "To sample a string $s$, we first sample $n$ tokens from $A$ uniformly at random.\n",
    "To sample the next token $s_i$, we take its preceding n-gram $g = s_{i-n}, \\dots, s_{i-1}$, look up its privileged token $t_g$, and then sample a token from $A$ with *$k \\times$ relative probability* $p_k = k \\cdot p_u$ for $t_g$ and probability $p_u$ for each other token $t \\in A \\setminus \\{t_g\\}$.\n",
    "I.e. we are $k$ times more likely to sample the privileged token $t_g$ as a continuation of $g$ than any other single token in $A$.\n",
    "We obtain $p_k$ as $p_k = \\frac{k}{|A| - 1 + k}$ and $p_u = \\frac{1 - p_k}{|A| - 1} = \\frac{1}{|A| - 1 + k}$.\n",
    "For $k \\geq 1$, increasing the relative probability factor $k$ (and thus $p_k$) lowers the conditional entropy $H_n(s)$ of the string $s$.\n",
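    "\n",
    "The sampling procedure above can be sketched in a few lines of Python. This is a minimal illustration, not the project's implementation: the names `continuation_probs` and `sample_string`, the dictionary-based mapping, and the example mapping are assumptions made for this sketch, and it assumes $|A| \\geq 2$.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def continuation_probs(alphabet_size, k):\n",
    "    # p_k for the privileged token, p_u for each of the other tokens\n",
    "    p_k = k / (alphabet_size - 1 + k)\n",
    "    p_u = (1 - p_k) / (alphabet_size - 1)\n",
    "    return p_k, p_u\n",
    "\n",
    "\n",
    "def sample_string(alphabet, privileged, n, k, length, rng):\n",
    "    # privileged maps each n-gram (a tuple of tokens) to its privileged token\n",
    "    p_k, p_u = continuation_probs(len(alphabet), k)\n",
    "    # The first n tokens are drawn uniformly at random\n",
    "    tokens = [rng.choice(alphabet) for _ in range(n)]\n",
    "    while len(tokens) < length:\n",
    "        t_g = privileged[tuple(tokens[-n:])]\n",
    "        weights = [p_k if t == t_g else p_u for t in alphabet]\n",
    "        tokens.append(rng.choices(alphabet, weights=weights, k=1)[0])\n",
    "    return tokens\n",
    "\n",
    "\n",
    "# Illustrative mapping for 2-grams over {a, b} (hypothetical example values)\n",
    "alphabet = ['a', 'b']\n",
    "privileged = {\n",
    "    ('a', 'a'): 'b',\n",
    "    ('a', 'b'): 'b',\n",
    "    ('b', 'a'): 'a',\n",
    "    ('b', 'b'): 'a',\n",
    "}\n",
    "s = sample_string(alphabet, privileged, n=2, k=16, length=1024, rng=random.Random(0))\n",
    "```\n",
    "\n",
    "Passing a seeded `random.Random` instance keeps the sketch reproducible without touching the global random state.\n",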
    "\n",
    "**Ensuring the same level of unconditional entropy**:\n",
    "To ensure that strings with different $p_k$ have the same unconditional entropy $H(s)$, we ensure that each token $t \\in A$ appears the same number of times as a privileged continuation token.\n",
    "I.e. for 1-grams, where there are $|A|$ combinations (single tokens from $A$), each $t \\in A$ appears once as the privileged token of a 1-gram.\n",
    "For 2-grams, with $|A|^2$ possible combinations, each token appears $|A|$ times as the privileged token, etc.\n",
    "E.g. for 2-grams over $A = \\{a, b\\}$ a privileged token mapping $aa \\rightarrow b, ab \\rightarrow b, ba \\rightarrow a, bb \\rightarrow a$ would be valid, whereas the mapping $aa \\rightarrow b, ab \\rightarrow b, ba \\rightarrow b, bb \\rightarrow b$ would not be.\n",
    "Making each token appear the same number of times as privileged continuation ensures that the overall probability of each $t \\in A$ is the same, and thus the unconditional entropy of the strings is the same.\n",
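    "\n",
    "One simple way to build such a balanced mapping is to repeat every token $|A|^{n-1}$ times, shuffle the result, and assign it to the n-grams in order. The helper name `balanced_privileged_mapping` is hypothetical, and the project code may construct its mappings differently.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def balanced_privileged_mapping(alphabet, n, rng):\n",
    "    # All |A|^n n-grams over the alphabet, as tuples of tokens\n",
    "    ngrams = list(itertools.product(alphabet, repeat=n))\n",
    "    # Every token is used exactly |A|^(n-1) times as a privileged continuation\n",
    "    targets = list(alphabet) * (len(ngrams) // len(alphabet))\n",
    "    rng.shuffle(targets)\n",
    "    return dict(zip(ngrams, targets))\n",
    "\n",
    "\n",
    "mapping = balanced_privileged_mapping(['a', 'b', 'c'], n=2, rng=random.Random(0))\n",
    "# How often each token appears as a privileged continuation\n",
    "counts = Counter(mapping.values())\n",
    "```\n",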
    "\n",
    "**Model training**:\n",
    "As usual, we train models for 100 epochs to memorize strings with alphabets of different sizes (i.e. entropy levels) and record their memorization dynamics.\n",
    "We also compute the empirical unconditional and conditional entropy of the sampled strings.\n",
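    "\n",
    "As a sketch of how these empirical entropies can be estimated from a sampled string, the snippet below uses plug-in estimates from token and n-gram counts, via $H_n = H(g, t) - H(g)$. The function names and the use of base-2 logarithms are assumptions, not the project's API.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def entropy(counts):\n",
    "    # Plug-in entropy (in bits) of the distribution given by the counts\n",
    "    total = sum(counts.values())\n",
    "    return -sum(c / total * math.log2(c / total) for c in counts.values())\n",
    "\n",
    "\n",
    "def unconditional_entropy(tokens):\n",
    "    return entropy(Counter(tokens))\n",
    "\n",
    "\n",
    "def conditional_entropy(tokens, n):\n",
    "    # H(next token | preceding n-gram) = H(n-gram, next token) - H(n-gram)\n",
    "    positions = range(n, len(tokens))\n",
    "    joint = Counter(tuple(tokens[i - n:i + 1]) for i in positions)\n",
    "    context = Counter(tuple(tokens[i - n:i]) for i in positions)\n",
    "    return entropy(joint) - entropy(context)\n",
    "```\n",
    "\n",
    "For a strictly alternating string such as `abab...`, the unconditional entropy is 1 bit while the 1-conditional entropy is 0, because each token fully determines its successor.\n",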
    "\n",
    "## Experiment\n",
    "\n",
    "We use an n-gram length of $n = 1$, a relative probability factor of $k = 16$, and strings of length $1024$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a0c87c-70a0-4ec1-82b6-c472cecdf857",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALPHABET_SIZES = [2, 7, 26]\n",
    "LEGEND_TITLE = \"Alphabet Size\"\n",
    "VARIATION_DIMENSION = \"alphabet-size\"\n",
    "\n",
    "def load_results(\n",
    "    model: tuple[str, str],\n",
    "    seed_ids: list[int],\n",
    "    alphabet_sizes: list[int] = ALPHABET_SIZES,\n",
    ") -> tuple[dict, dict]:\n",
    "    \"\"\"Load memorization and prefix-mapping results, keyed by alphabet size.\"\"\"\n",
    "    model_short, model_long = model\n",
    "\n",
    "    mem_results = {}\n",
    "    prefix_mapping_results = {}\n",
    "    for alphabet_size in alphabet_sizes:\n",
    "        res_name = f\"l = {alphabet_size}\"\n",
    "        mem_res = mem_res_util.load(\n",
    "            [f\"rand_a-lat-{alphabet_size}_t-1024_ngl-1\", model_long],\n",
    "            seed_ids,\n",
    "        )\n",
    "        # Skip configurations for which no results were found\n",
    "        if len(mem_res) > 0:\n",
    "            mem_results[res_name] = mem_res\n",
    "        pm_res = res_util.load(\n",
    "            [\"conditional_probability\", f\"{model_short}_a-{alphabet_size}_rp-16_n-1\"],\n",
    "            seed_ids,\n",
    "        )\n",
    "        if len(pm_res) > 0:\n",
    "            prefix_mapping_results[res_name] = pm_res\n",
    "    return mem_results, prefix_mapping_results\n",
    "\n",
    "def show(results: tuple[dict, dict]):\n",
    "    \"\"\"Show memorization dynamics and return the prefix-length figures.\"\"\"\n",
    "    mem_results, pm_results = results\n",
    "    res_util.show_memorization_dynamics(\n",
    "        mem_results,\n",
    "        LEGEND_TITLE,\n",
    "    )\n",
    "    figs = res_util.show_epoch_prefix_lengths(\n",
    "        pm_results,\n",
    "    )\n",
    "    return figs\n",
    "\n",
    "def produce_paper_plots(\n",
    "    figures: dict,\n",
    "    model: str,\n",
    "):\n",
    "    \"\"\"Produce one paper plot per configuration in `figures`.\"\"\"\n",
    "    for i, (config_name, config_fig) in enumerate(figures.items()):\n",
    "        # if config_name == \"l = 2\":\n",
    "        #     continue\n",
    "        res_util.produce_epoch_prefix_paper_plot(\n",
    "            config_fig,\n",
    "            model,\n",
    "            # config_name has the form \"l = <size>\"; [4:] keeps only the size\n",
    "            variation_dimension=f\"{VARIATION_DIMENSION}_a-{config_name[4:]}\",\n",
    "            show_legend=(i == 4),\n",
    "            figure_folder=\"prefix_mappings/prefix_lengths/conditional_probs\",\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45c58ed0-30cb-4902-bd64-ccf912cb0309",
   "metadata": {},
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe49afea-c0f5-4129-81ac-8698a506fed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(5))\n",
    "results = load_results(\n",
    "    (\"pyt-1b\", \"pythia-1b\"),\n",
    "    seed_ids,\n",
    ")\n",
    "pyt_1b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8bce04a-a18f-4fd8-9b8d-2510bea2bd9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    pyt_1b_figures,\n",
    "    \"pythia-1b\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062504d2-163a-4bd0-839c-7888760a418b",
   "metadata": {},
   "source": [
    "## Phi-2.7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de2be29e-a51d-416e-8a23-77ec9f1784ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(3))\n",
    "results = load_results(\n",
    "    (\"phi-2.7b\", \"phi-2\"),\n",
    "    seed_ids,\n",
    ")\n",
    "phi_2_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffd892a8-5e8a-4d40-9fd1-19fa36ad9ae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    phi_2_figures,\n",
    "    \"phi-2.7b\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7662c84-6a16-4fe0-ae69-72ad000980d1",
   "metadata": {},
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58aca706-3901-445b-b137-cd1a15b57d4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(2))\n",
    "results = load_results(\n",
    "    (\"llama2-13b\", \"Llama-2-13b-hf\"),\n",
    "    seed_ids,\n",
    ")\n",
    "llama_13b_figures = show(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3184fa74-fc02-48c3-925d-110a7696eb4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "produce_paper_plots(\n",
    "    llama_13b_figures,\n",
    "    \"llama2-13b\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"prefix_length/conditional_probabilities\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd02a39-ad0f-4b8c-8fd9-510491ae507b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
