{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2815c25-e436-4aa4-afd1-c21374e2eb75",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from IPython.display import Markdown as md\n",
    "\n",
    "from defs import MODELS\n",
    "from experiments.conditional_prob_mem_dynamics import results as res_util\n",
    "from experiments.memorization_dynamics import results as md_res_util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "519d2453-249b-448b-9cf0-8e6aeab00929",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the experiment README; filename= makes a missing file raise instead of rendering the path as text.\n",
    "# NOTE(review): the relative path presumably relies on setup_notebook() having set the working directory - confirm.\n",
    "display(md(filename=\"./experiments/conditional_prob_mem_dynamics/README.md\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc4d334f-b6cc-4865-8ccb-617e2d8782d9",
   "metadata": {},
   "source": [
    "# Experiments on changing the Condition Length/n-gram Size $n$\n",
    "\n",
    "We want to determine the effect of the condition length, i.e. $n$, of conditional entropy $H_n(s)$ on memorization difficulty.\n",
    "\n",
    "We fix a relative probability level $k = 16$ and train and evaluate models on strings with different condition/n-gram lengths $n$.\n",
    "We use $n \\in \\{0, 1, 2, 3, 4\\}$.\n",
    "For $n = 0$ the unconditional entropy is not modified.\n",
    "\n",
    "## Takeaways\n",
    "\n",
    "Varying the condition length $n$ does not seem to impact memorization speed/difficulty.\n",
    "All strings are fully memorized roughly at the same epoch, with no consistent relationship between condition length $n$ and full memorization epoch.\n",
    "\n",
    "For larger $n$ (esp. 3 and 4), models seem not to be able to learn the conditional distribution, especially for larger alphabets (7 and 26).\n",
    "This may be because for larger $n$ and larger alphabets, specific n-grams become very sparse in a 1024 token string, so models might not be able to infer the distribution.\n",
    "The empirical entropy tables show that even though the conditional entropy should be the same for different $n$, it strongly decreases for $l = 7, 26$ for larger $n$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35828d-0e51-4f7c-9f96-e2e20545394e",
   "metadata": {},
   "outputs": [],
   "source": [
    "VARIATIOANONYMOUS = [1, 2, 3, 4]\n",
    "LEGEND_TITLE = \"ngram length\"\n",
    "ALPHABET_SIZES = [2, 7, 26]\n",
    "\n",
    "\n",
    "def load_results(\n",
    "    model: str,\n",
    "    alphabet_size: int,\n",
    "    relative_probability: int,\n",
    "    seed_ids: list[int],\n",
    "    variations: list[int] = VARIATIOANONYMOUS,\n",
    ") -> dict:\n",
    "    \"\"\"Load the n = 0 baseline plus one result set per n-gram length.\n",
    "\n",
    "    The n = 0 entry is read through ``md_res_util.load``; every value in\n",
    "    ``variations`` is read through ``res_util.load``.\n",
    "\n",
    "    Returns:\n",
    "        A dict keyed by labels of the form ``n = <length>``, with the\n",
    "        baseline entry first.\n",
    "    \"\"\"\n",
    "    # Baseline entry, loaded through the memorization_dynamics helper.\n",
    "    loaded = {\n",
    "        \"n = 0\": md_res_util.load(\n",
    "            [f\"rand_a-lat-{alphabet_size}_t-1024\", MODELS[model][0]],\n",
    "            seed_ids,\n",
    "        )\n",
    "    }\n",
    "    # One conditional-probability result set per requested n-gram length.\n",
    "    for ngram_length in variations:\n",
    "        label = f\"n = {ngram_length}\"\n",
    "        loaded[label] = res_util.load(\n",
    "            f\"{model}_a-{alphabet_size}_t-1024_rp-{relative_probability}_n-{ngram_length}\",\n",
    "            seed_ids,\n",
    "        )\n",
    "    return loaded\n",
    "\n",
    "def show_results(results: dict):\n",
    "    \"\"\"Forward ``results`` to ``res_util.show_dynamics`` and return its result.\n",
    "\n",
    "    The KLD, in-context-learning, cumulative-probability and entropy flags\n",
    "    are all passed as ``False``; ``LEGEND_TITLE`` is used as the legend title.\n",
    "    \"\"\"\n",
    "    disabled_views = {\n",
    "        \"show_kld\": False,\n",
    "        \"show_in_context_learning\": False,\n",
    "        \"show_cum_prob\": False,\n",
    "        \"show_entropy\": False,\n",
    "    }\n",
    "    figures = res_util.show_dynamics(results, LEGEND_TITLE, **disabled_views)\n",
    "    return figures\n",
    "\n",
    "def show_multi_alphabet_results(\n",
    "    model: str,\n",
    "    relative_probability: int,\n",
    "    seed_ids: list[int],\n",
    "    variations: list[int] = VARIATIOANONYMOUS,\n",
    "    alphabet_sizes: list[int] = ALPHABET_SIZES,\n",
    "    data_example_for: int | None = 2,\n",
    "):\n",
    "    \"\"\"Show results for every alphabet size, one section per size.\n",
    "\n",
    "    For each alphabet size this loads the results, shows the entropy\n",
    "    summary, optionally prints an example string, and shows the dynamics.\n",
    "\n",
    "    Args:\n",
    "        model: Model key forwarded to ``load_results``.\n",
    "        relative_probability: Relative probability forwarded to ``load_results``.\n",
    "        seed_ids: Seed ids forwarded to ``load_results``.\n",
    "        variations: n-gram lengths to load in addition to the n = 0 baseline.\n",
    "        alphabet_sizes: Alphabet sizes to iterate over.\n",
    "        data_example_for: n-gram length whose example string is printed\n",
    "            (``tokens[0]`` of the first loaded result); ``None`` skips the\n",
    "            example. Must be 0 or one of ``variations``, otherwise the\n",
    "            lookup raises ``KeyError``.\n",
    "    \"\"\"\n",
    "    for alphabet_size in alphabet_sizes:\n",
    "        display(md(f\"### Alphabet Size $l = {alphabet_size}$\"))\n",
    "        results = load_results(\n",
    "            model,\n",
    "            alphabet_size,\n",
    "            relative_probability,\n",
    "            seed_ids,\n",
    "            variations=variations,\n",
    "        )\n",
    "        # n = 0 is always loaded as the baseline, so include it in the table.\n",
    "        ngram_lengths = [0] + variations\n",
    "        res_util.show_eANONYMOUSrical_entropies(\n",
    "            results,\n",
    "            ngram_lengths,\n",
    "            \"Condition length (n)\",\n",
    "            [f\"n = {n}\" for n in ngram_lengths],\n",
    "        )\n",
    "        if data_example_for is not None:\n",
    "            data_example = \"\".join(\n",
    "                results[f\"n = {data_example_for}\"][0].value.data.tokens[0]\n",
    "            )\n",
    "            # Report the n-gram length actually used for the example; the\n",
    "            # label used to be hard-coded to \"n = 1\" regardless of the argument.\n",
    "            print(\n",
    "                f\"Example string for ngram length n = {data_example_for}:\\n\",\n",
    "                data_example,\n",
    "            )\n",
    "        _ = show_results(results)\n",
    "\n",
    "# produce_paper_plots = partial(\n",
    "#     res_util.produce_paper_plots,\n",
    "#     figure_folder=\"mem_dynamics_rw_validation/batch_size\",\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3942594f-5caf-4e14-92ac-9c3be7cd91f9",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Pythia-1B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39609c78-b71c-4360-a271-0202b5475e0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_multi_alphabet_results(\n",
    "    \"pyt-1b\",\n",
    "    relative_probability=16,\n",
    "    seed_ids=list(range(3)),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6756e93b-79e3-42a1-ba9f-5f1ce8d7aad0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9af2c0a1-fa48-4555-8d4f-d759b5450f52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# seed_ids = list(range(1))\n",
    "# results = results = load_results(\n",
    "#     \"pyt-1b\",\n",
    "#     26,\n",
    "#     relative_probability=16,\n",
    "#     seed_ids=seed_ids,\n",
    "# )   \n",
    "# pyt_1b_figures = show_results(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ef11fc8-5893-48d6-8df0-4f55984326e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"\".join(results[\"n = 3\"][0].value.random_data.tokens[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5789a54f-fe86-488e-9f3c-d4932ed80c47",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Llama2-13B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8673f9-8c18-473d-80dd-d92398790c79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_ids = list(range(1))\n",
    "# NOTE(review): \"lama2-13b\" may be a typo for \"llama2-13b\" - confirm against the keys of MODELS.\n",
    "results = load_results(\n",
    "    \"lama2-13b\",\n",
    "    26,\n",
    "    # Same relative probability level as the other experiments in this notebook.\n",
    "    relative_probability=16,\n",
    "    seed_ids=seed_ids,\n",
    ")\n",
    "llama2_13b_figures = show_results(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c61392a4-cad8-412d-bd03-ed525213a7b7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Phi-2 (2.7B parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462bbf7d-d6e4-4bdc-8c11-b692de1df638",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ids = list(range(3))\n",
    "# Same per-alphabet overview as for Pythia-1B; the previous call targeted a\n",
    "# helper (show_multi_results with batch_sizes) that is not defined in this notebook.\n",
    "show_multi_alphabet_results(\n",
    "    \"phi-2.7b\",\n",
    "    relative_probability=16,\n",
    "    seed_ids=seed_ids,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f0bc98-5ace-4c16-b155-19b34e145299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"ngram_length\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de2b4111-16a1-4c26-aa8a-7d6241d966a1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
