{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Project bootstrap. Presumably points the notebook at the repository root three\n",
    "# directories up (working directory and/or import path) - TODO confirm in lib_project.notebook.\n",
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "# autoreload mode 2: re-import edited modules before each cell execution.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from IPython.display import Markdown as md\n",
    "\n",
    "from defs import MODELS\n",
    "from experiments.conditional_prob_mem_dynamics import results as res_util\n",
    "from experiments.memorization_dynamics import results as md_res_util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "788b4d10-082f-4d0a-a06e-e85b765be39f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the experiment README inline. The path is passed as `filename=` so the file\n",
    "# is read explicitly; a wrong working directory then raises instead of silently\n",
    "# rendering the path string itself as Markdown.\n",
    "display(md(filename=\"./experiments/conditional_prob_mem_dynamics/README.md\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Experiments on changing the Relative Probability $p_k$\n",
    "\n",
    "We want to determine whether at the same level of unconditional entropy $H(s)$, strings with different conditional entropies $H_n(s)$ differ in how hard they are to memorize by the models.\n",
    "\n",
    "We fix $n = 1$ and train and evaluate models on strings with different relative probabilities $p_k$, i.e. where the privileged continuation tokens are $k$ times as likely to appear after their $n$-grams as the remaining tokens from $A$.\n",
    "We use $k \\in \\{1, 2, 4, 8, 16, 32, 64\\}$.\n",
    "For $k = 1$ the unconditional entropy $H(s)$ is the same as the conditional entropy $H_1(s) = H(s)$.\n",
    "\n",
    "## Takeaways\n",
    "\n",
    "The conditional entropy of strings does not seem to affect their memorization difficulty.\n",
    "While it changes the plateau-level of the Guessing-Phase, all strings are fully memorized roughly at the same epoch, with no consistent relationship between relative probability and full memorization epoch.\n",
    "The effect of unconditional entropy on the other hand is much stronger.\n",
    "\n",
    "The conditional entropy changes the starting point of the loss and accuracy curves.\n",
    "Strings with lower conditional entropy, i.e. higher relative probability, start with lower loss/higher accuracy.\n",
    "Presumably this happens because models can use in-context-learning to initially predict more of the tokens correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35828d-0e51-4f7c-9f96-e2e20545394e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative probabilities k (in addition to the k = 1 baseline) that are loaded and shown.\n",
    "# Wider sweep kept for reference: [2, 4, 8, 16, 32, 64]\n",
    "# NOTE(review): the constant name looks mangled (probably meant to be VARIATIONS) - confirm.\n",
    "VARIATIOANONYMOUS = [4, 16, 64]\n",
    "# Legend title that show_results passes to res_util.show_dynamics.\n",
    "LEGEND_TITLE = \"Relative Probability\"\n",
    "# Alphabet sizes that show_multi_alphabet_results iterates over by default.\n",
    "ALPHABET_SIZES = [2, 7, 26]\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    ngram_size: int,\n",
    "    seed_ids: list[int],\n",
    "    variations: list[int] = VARIATIOANONYMOUS,\n",
    ") -> dict:\n",
    "    \"\"\"Load the k = 1 baseline plus one result set per relative probability.\n",
    "\n",
    "    The baseline is read with `md_res_util.load` (memorization-dynamics\n",
    "    experiment); every entry of `variations` is read with `res_util.load`\n",
    "    (conditional-probability experiment).\n",
    "\n",
    "    Args:\n",
    "        model: Key into `MODELS`; also the prefix of the experiment names.\n",
    "        alphabet_size: Alphabet size embedded in the experiment names.\n",
    "        ngram_size: n-gram size embedded in the conditional-probability names.\n",
    "        seed_ids: Seed ids forwarded to both loaders.\n",
    "        variations: Relative probabilities k to load besides the baseline.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping 'k = <value>' to the loader output, baseline entry first.\n",
    "    \"\"\"\n",
    "    baseline_id = [f\"rand_a-lat-{alphabet_size}_t-1024\", MODELS[model][0]]\n",
    "    loaded = {\"k = 1\": md_res_util.load(baseline_id, seed_ids)}\n",
    "    for rel_prob in variations:\n",
    "        run_name = f\"{model}_a-{alphabet_size}_t-1024_rp-{rel_prob}_n-{ngram_size}\"\n",
    "        loaded[f\"k = {rel_prob}\"] = res_util.load(run_name, seed_ids)\n",
    "    return loaded\n",
    "\n",
    "def show_results(results: dict):\n",
    "    \"\"\"Show the dynamics of `results` with the legend titled LEGEND_TITLE.\n",
    "\n",
    "    Delegates to `res_util.show_dynamics` with the KLD, in-context-learning,\n",
    "    cumulative-probability and entropy options switched off, and returns\n",
    "    whatever that call returns.\n",
    "    \"\"\"\n",
    "    disabled_options = {\n",
    "        \"show_kld\": False,\n",
    "        \"show_in_context_learning\": False,\n",
    "        \"show_cum_prob\": False,\n",
    "        \"show_entropy\": False,\n",
    "    }\n",
    "    return res_util.show_dynamics(results, LEGEND_TITLE, **disabled_options)\n",
    "\n",
    "def show_multi_alphabet_results(\n",
    "    model: str,\n",
    "    ngram_size: int,\n",
    "    seed_ids: list[int],\n",
    "    variations: list[int] = VARIATIOANONYMOUS,\n",
    "    alphabet_sizes: list[int] = ALPHABET_SIZES,\n",
    "    data_example_for: int | None = 16,\n",
    "):\n",
    "    \"\"\"Show the results for every alphabet size and collect the figures.\n",
    "\n",
    "    For each alphabet size this displays a heading, prints the privileged-token\n",
    "    probability per relative probability, shows the entropies via `res_util`,\n",
    "    optionally prints one example string, and shows the dynamics.\n",
    "\n",
    "    Args:\n",
    "        model: Model key forwarded to `load_results`.\n",
    "        ngram_size: n-gram size forwarded to `load_results` and the entropy view.\n",
    "        seed_ids: Seed ids forwarded to `load_results`.\n",
    "        variations: Relative probabilities k to show besides the k = 1 baseline.\n",
    "        alphabet_sizes: Alphabet sizes to iterate over.\n",
    "        data_example_for: Relative probability whose first token sequence is\n",
    "            printed as an example; must be 1 or contained in `variations`\n",
    "            (otherwise the lookup raises KeyError). None skips the example.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping 'a-<alphabet_size>' to the value returned by `show_results`.\n",
    "    \"\"\"\n",
    "    figures = {}\n",
    "    for alphabet_size in alphabet_sizes:\n",
    "        display(md(f\"### Alphabet Size $l = {alphabet_size}$\"))\n",
    "        results = load_results(\n",
    "            model,\n",
    "            alphabet_size,\n",
    "            ngram_size,\n",
    "            seed_ids,\n",
    "            variations=variations,\n",
    "        )\n",
    "        # The k = 1 baseline is entered as the uniform probability 1 / alphabet_size;\n",
    "        # the other values are read from the config of the first loaded result.\n",
    "        privileged_token_probs = {\n",
    "            1: 1 / alphabet_size,\n",
    "        } | {\n",
    "            rel_prob: results[f\"k = {rel_prob}\"][0].config.data.first_char_prob\n",
    "            for rel_prob in variations\n",
    "        }\n",
    "        print(\"Relative probabilities: \" + \", \".join(\n",
    "            f\"k = {rel_prob}: {prob:.3f}\" for rel_prob, prob in\n",
    "            privileged_token_probs.items()\n",
    "        ))\n",
    "        # NOTE(review): the helper name below looks mangled (possibly\n",
    "        # show_empirical_entropies) - confirm against res_util before renaming.\n",
    "        res_util.show_eANONYMOUSrical_entropies(\n",
    "            results,\n",
    "            ngram_size,\n",
    "            \"Relative Prob\",\n",
    "            # Unpacking keeps this working when `variations` is a tuple or other iterable.\n",
    "            [f\"k = {k}\" for k in [1, *variations]],\n",
    "        )\n",
    "        if data_example_for is not None:\n",
    "            data_example = \"\".join(\n",
    "                results[f\"k = {data_example_for}\"][0].value.data.tokens[0]\n",
    "            )\n",
    "            # Label the example with the requested k instead of a hard-coded x16.\n",
    "            print(\n",
    "                f\"Example string for relative probability x{data_example_for}:\\n\",\n",
    "                data_example,\n",
    "            )\n",
    "        figures[f\"a-{alphabet_size}\"] = show_results(results)\n",
    "    return figures\n",
    "\n",
    "# res_util.produce_paper_plots with the figure_folder keyword preset for this notebook.\n",
    "produce_paper_plots = partial(\n",
    "    res_util.produce_paper_plots,\n",
    "    figure_folder=\"memorability/conditional_probability\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3942594f-5caf-4e14-92ac-9c3be7cd91f9",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbbfd8f7-cefe-4e88-9777-a189576d91c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pythia-1B, n = 1, seeds 0-4; default relative probabilities and alphabet sizes.\n",
    "pyt_figures = show_multi_alphabet_results(\n",
    "    \"pyt-1b\",\n",
    "    ngram_size=1,\n",
    "    seed_ids=list(range(5)),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a8dec1c-4aa2-4147-a1c2-b0c60b656ea8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the Pythia-1B figures to the produce_paper_plots partial (figure folder preset), legend on.\n",
    "produce_paper_plots(\n",
    "    pyt_figures,\n",
    "    \"pythia-1b\",\n",
    "    \"rel-prob\",\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5789a54f-fe86-488e-9f3c-d4932ed80c47",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8673f9-8c18-473d-80dd-d92398790c79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Llama2-13B, n = 1, seeds 0-4; default relative probabilities and alphabet sizes.\n",
    "llama_figures = show_multi_alphabet_results(\n",
    "    \"llama2-13b\",\n",
    "    ngram_size=1,\n",
    "    seed_ids=list(range(5)),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d203189d-0fc3-4f79-a007-64a2925c9474",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the Llama2-13B figures to the produce_paper_plots partial (figure folder preset), legend on.\n",
    "produce_paper_plots(\n",
    "    llama_figures,\n",
    "    \"llama2-13b\",\n",
    "    \"rel-prob\",\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c61392a4-cad8-412d-bd03-ed525213a7b7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462bbf7d-d6e4-4bdc-8c11-b692de1df638",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phi-2 (2.7B), n = 1, seeds 0-4; default relative probabilities and alphabet sizes.\n",
    "phi_figures = show_multi_alphabet_results(\n",
    "    \"phi-2.7b\",\n",
    "    ngram_size=1,\n",
    "    seed_ids=list(range(5)),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d412bde-418d-4ce3-874f-fdf30d5e25dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the Phi-2 figures to the produce_paper_plots partial (figure folder preset), legend on.\n",
    "produce_paper_plots(\n",
    "    phi_figures,\n",
    "    \"phi-2.7b\",\n",
    "    \"rel-prob\",\n",
    "    show_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f0bc98-5ace-4c16-b155-19b34e145299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "# NOTE(review): what publish() uploads and where it goes is defined in res_util -\n",
    "# confirm before re-running, since this cell has effects outside the notebook.\n",
    "res_util.publish(\"relative_probability\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de626a7a-9492-4fe0-ad94-b525e91d1c84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
