{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from IPython.display import Markdown as md, display\n",
    "\n",
    "from lib_project.visualization import with_paper_style\n",
    "from defs import BASE_FIGURE_DIR\n",
    "from experiments.memorization_dynamics import results as res_util\n",
    "from utils import memorization_order as mem_ord"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1206ca3-0050-4a98-9e42-355d6aeaab2c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Relationship between memorization order and conflicting prefixes\n",
    "\n",
    "We examine whether a token's memorization order is related to the agreeing and disagreeing prefixes it shares with other tokens.\n",
    "A token `s[i]` at position $i$ in string $s$ has an agreeing prefix of length $k$ with another token `s[j]` if `s[i-k:i] == s[j-k:j]` and `s[i] == s[j]`; it has a disagreeing prefix of length $k$ with `s[j]` if `s[i-k:i] == s[j-k:j]` but `s[i] != s[j]`."
   ]
  },
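  {
   "cell_type": "markdown",
   "id": "sketch-prefix-agreement-note",
   "metadata": {},
   "source": [
    "The definition above can be illustrated on a toy string.\n",
    "The next cell is a minimal sketch, not the implementation behind `mem_ord.prefix_agreement`: the helper `sketch_prefix_agreement` is a hypothetical name, and it assumes that a disagreeing prefix means the same length-$k$ prefix followed by a different token.\n",
    "For every position `i >= k` it returns the pair (number of agreeing positions, number of disagreeing positions)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-prefix-agreement-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def sketch_prefix_agreement(tokens, k):\n",
    "    \"\"\"For each position i >= k, count other positions sharing its length-k prefix.\"\"\"\n",
    "    # Group positions by the k tokens that precede them.\n",
    "    positions_by_prefix = defaultdict(list)\n",
    "    for i in range(k, len(tokens)):\n",
    "        positions_by_prefix[tuple(tokens[i - k:i])].append(i)\n",
    "\n",
    "    counts = {}\n",
    "    for positions in positions_by_prefix.values():\n",
    "        for i in positions:\n",
    "            # Agreeing: same prefix, same token. Disagreeing: same prefix, other token.\n",
    "            agree = sum(1 for j in positions if j != i and tokens[j] == tokens[i])\n",
    "            counts[i] = (agree, len(positions) - 1 - agree)\n",
    "    return counts\n",
    "\n",
    "\n",
    "# Positions 2, 5 and 8 all follow the prefix 'ab' in this toy string:\n",
    "# 2 and 8 agree with each other ('c'), while 5 ('d') disagrees with both.\n",
    "sketch_counts = sketch_prefix_agreement('abcabdabc', k=2)\n",
    "sketch_counts[2], sketch_counts[5], sketch_counts[8]"
   ]
  },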
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2855ef66-c474-4c7d-8d46-2662403f57e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "sids = [0]\n",
    "results = {\n",
    "    \"Pythia 1B\": res_util.load([\"rand_a-lat-26_t-1024\", \"pythia-1b\"], sids)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59cacab2-f632-442a-96fe-34a70376fea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = results[\"Pythia 1B\"][0].value\n",
    "mem_res = result.memorization_log\n",
    "# mem_res.head(30)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b623594-b182-4410-a961-624b1b7069f4",
   "metadata": {},
   "source": [
    "## Stable memorization order\n",
    "\n",
    "We investigate the relationship between prefix agreement and stable memorization order.\n",
    "The stable memorization epoch $t$ of a token `s[i]` is the first epoch at which it is predicted correctly and after which it is never predicted incorrectly again, i.e. the epoch at which its prediction can be considered converged."
   ]
  },
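  {
   "cell_type": "markdown",
   "id": "sketch-stable-epoch-note",
   "metadata": {},
   "source": [
    "To make the definition concrete, the next cell sketches it for a single token.\n",
    "It is an illustration under assumptions, not the code behind `mem_ord.stable_memorization_order`: the helper `sketch_stable_epoch` is a hypothetical name, and the input is assumed to be a list of booleans with one entry per epoch that says whether the token was predicted correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-stable-epoch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sketch_stable_epoch(correct_by_epoch):\n",
    "    \"\"\"Return the first epoch of the trailing run of correct predictions, or None.\"\"\"\n",
    "    stable = None\n",
    "    # Walk backwards from the last epoch until the first incorrect prediction.\n",
    "    for epoch in range(len(correct_by_epoch) - 1, -1, -1):\n",
    "        if not correct_by_epoch[epoch]:\n",
    "            break\n",
    "        stable = epoch\n",
    "    return stable\n",
    "\n",
    "\n",
    "# Correct at epoch 1, wrong again at epoch 2, correct from epoch 3 onwards.\n",
    "sketch_stable = sketch_stable_epoch([False, True, False, True, True])\n",
    "sketch_stable"
   ]
  },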
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a618960-14e9-42ac-9f5f-3b65407eb776",
   "metadata": {},
   "outputs": [],
   "source": [
    "stable_order = mem_ord.stable_memorization_order(mem_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c749245f-8cb0-4175-9d48-606f32f1b2c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "for prefix_length in range(1, 5):\n",
    "    display(md(f\"### Prefix length {prefix_length} agreement/disagreement\"))\n",
    "    prefix_agreement = mem_ord.prefix_agreement(mem_res, prefix_length=prefix_length)\n",
    "    mem_ord.plot_order_agreement_disagreement(stable_order, prefix_agreement).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a2096bb-d833-47f3-95e7-82b704b2475f",
   "metadata": {},
   "source": [
    "## Initial memorization order\n",
    "\n",
    "We investigate the relationship between prefix agreement and initial memorization order.\n",
    "The initial memorization epoch $t$ of a token `s[i]` is the epoch when it is first predicted correctly, regardless of whether it is predicted incorrectly again later."
   ]
  },
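  {
   "cell_type": "markdown",
   "id": "sketch-initial-epoch-note",
   "metadata": {},
   "source": [
    "The next cell sketches this definition for a single token.\n",
    "It is only an illustration, not the code behind `mem_ord.initial_memorization_order`: the helper `sketch_initial_epoch` is a hypothetical name, and the input is assumed to be a list of booleans with one entry per epoch that says whether the token was predicted correctly.\n",
    "In the toy history the token is first correct at epoch 1 but wrong again at epoch 2, so its initial memorization epoch is earlier than the epoch at which its prediction converges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-initial-epoch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sketch_initial_epoch(correct_by_epoch):\n",
    "    \"\"\"Return the first epoch with a correct prediction, or None if there is none.\"\"\"\n",
    "    return next((epoch for epoch, ok in enumerate(correct_by_epoch) if ok), None)\n",
    "\n",
    "\n",
    "# Correct at epoch 1, wrong again at epoch 2, correct from epoch 3 onwards.\n",
    "sketch_initial = sketch_initial_epoch([False, True, False, True, True])\n",
    "sketch_initial"
   ]
  },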
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a32ec5dd-ec7b-4ec0-ba32-4f2545a543bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_order = mem_ord.initial_memorization_order(mem_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d86d2254-a2d5-4803-9ff6-b193c7d6f5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "for prefix_length in range(1, 5):\n",
    "    display(md(f\"### Prefix length {prefix_length} agreement/disagreement\"))\n",
    "    prefix_agreement = mem_ord.prefix_agreement(mem_res, prefix_length=prefix_length)\n",
    "    mem_ord.plot_order_agreement_disagreement(initial_order, prefix_agreement).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3341cec-0464-4638-a04a-ba103bad739c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Upload\n",
    "res_util.publish(\"order_analysis/prefix_order_rel\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
