{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_dl.analysis.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from IPython.display import display, Markdown as md\n",
    "import pandas as pd\n",
    "\n",
    "from defs import RANDOM_STRING_MODEL_DIR\n",
    "from utils.data.random_strings import (\n",
    "    RandomStringConfig,\n",
    "    get_random_strings,\n",
    ")\n",
    "from utils.finetuning.finetune import (\n",
    "    load_log,\n",
    ")\n",
    "\n",
    "from experiments.model_training import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b92cc8f4-3bc0-4cd0-bb7d-8bcfce60739c",
   "metadata": {},
   "source": [
    "## Generating data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d9e3e60-dd13-4d25-98d2-88b2b1bfb82a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One small example config; `conf` is inspected again in the next cell.\n",
    "example_settings = dict(\n",
    "    seed_id=1,\n",
    "    num_tokens=16,\n",
    "    num_partitions=1,\n",
    "    alphabet_size=7,\n",
    "    artifacts_dir=\"/home/exp/artifacts\",\n",
    ")\n",
    "conf = RandomStringConfig(**example_settings)\n",
    "\n",
    "# Last expression of the cell, so its return value is displayed.\n",
    "get_random_strings(conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95601499-9930-4365-a793-57b439cf518b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "conf.name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cdf650b-bc3e-4974-8daa-4e3d050bba85",
   "metadata": {},
   "source": [
    "### Generate all required data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6309e569-5bab-466d-b918-e48a85f6b100",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sweep every (seed, alphabet size, string length) combination.\n",
    "seed_ids = range(10)\n",
    "alphabet_sizes = (2, 4, 7, 13, 26)\n",
    "token_counts = (16, 32, 64, 128, 256, 512, 1024)\n",
    "\n",
    "for seed_id in seed_ids:\n",
    "    for alphabet_size in alphabet_sizes:\n",
    "        for num_tokens in token_counts:\n",
    "            conf = RandomStringConfig(\n",
    "                seed_id,\n",
    "                alphabet_size=alphabet_size,\n",
    "                num_tokens=num_tokens,\n",
    "                artifacts_dir=\"/home/exp/artifacts\",\n",
    "            )\n",
    "            get_random_strings(conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77b4d948-c576-4edf-81d0-bedf7f21c9f8",
   "metadata": {},
   "source": [
    "### Generate the data with different first char probs to match entropy level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9fec23-c322-4f7a-a8dc-e683f14b87f8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# First-character probabilities that lower the entropy of the\n",
    "# 26-character alphabet to that of a uniform alphabet of the smaller\n",
    "# size given by each key.\n",
    "first_char_probs = {\n",
    "    13: 0.41385763230705264,\n",
    "    7: 0.6040332620868685,\n",
    "    4: 0.7455315447635649,\n",
    "    2: 0.8913995544185638,\n",
    "}\n",
    "\n",
    "# Only the probability enters the config; the key merely records which\n",
    "# uniform alphabet size the value corresponds to, so iterate the values.\n",
    "for fcp in first_char_probs.values():\n",
    "    for seed_id in range(10):\n",
    "        conf = RandomStringConfig(\n",
    "            seed_id,\n",
    "            num_tokens=1024,\n",
    "            alphabet_size=26,\n",
    "            first_char_prob=fcp,\n",
    "            artifacts_dir=\"/home/exp/artifacts\",\n",
    "        )\n",
    "        get_random_strings(conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d32c6649-c0af-4aff-b2b0-25dcf2cd90b1",
   "metadata": {},
   "source": [
    "## Checking memorization logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55279102-19c2-41eb-9523-d39be1aa0a3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read one run's memorization log; `mem_log` is displayed in the next cell.\n",
    "num_tokens = 128\n",
    "data_id = f\"sid-0_a-26_t-{num_tokens}_p-1\"\n",
    "# Adjacent literals are concatenated, giving a single path string.\n",
    "log_path = (\n",
    "    \"../artifacts/random_strings/models/pythia-1b/\"\n",
    "    f\"{data_id}/epoch_100/logs/memorization.parquet\"\n",
    ")\n",
    "mem_log = pd.read_parquet(log_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71be227b-be47-4115-9590-7e45a277f209",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mem_log"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c763bd4-3ac7-4a35-97c4-c2eb864d4d47",
   "metadata": {},
   "source": [
    "### Loss log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f65c6f-6da2-4c26-959a-816d12fe4bff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read the loss log that sits next to the memorization log of the same run.\n",
    "num_tokens = 128\n",
    "data_id = f\"sid-0_a-26_t-{num_tokens}_p-1\"\n",
    "# Adjacent literals are concatenated, giving a single path string.\n",
    "log_path = (\n",
    "    \"../artifacts/random_strings/models/pythia-1b/\"\n",
    "    f\"{data_id}/epoch_100/logs/loss.parquet\"\n",
    ")\n",
    "\n",
    "# Last expression of the cell, so the frame is displayed.\n",
    "pd.read_parquet(log_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb8b19f7-ca1a-47c8-ab63-0bcadf974ea8",
   "metadata": {},
   "source": [
    "### Test config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd2d1de-8fe5-4453-ac8d-3b41c7e02171",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Memorization log written by the pythia-70m test configuration (epoch_1).\n",
    "data_id = \"sid-0_a-7_t-64_p-1\"\n",
    "# Adjacent literals are concatenated, giving a single path string.\n",
    "log_path = (\n",
    "    \"../artifacts/random_strings/models/pythia-70m/\"\n",
    "    f\"{data_id}/epoch_1/logs/memorization.parquet\"\n",
    ")\n",
    "\n",
    "# Last expression of the cell, so the frame is displayed.\n",
    "pd.read_parquet(log_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60d2a19b-2281-42f2-9739-e626640a42db",
   "metadata": {},
   "source": [
    "## Compute first character probability for desired entropy level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6014f7-b0eb-4499-aab1-31deeb7cb07b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "814f1a25-f091-4a13-a948-73738ecf5470",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def entropy(alphabet_size: int, first_char_prob: float | None = None) -> float:\n",
    "    \"\"\"Entropy (in nats) of an alphabet with one distinguished first symbol.\n",
    "\n",
    "    The first symbol has probability ``first_char_prob``; the remaining\n",
    "    ``alphabet_size - 1`` symbols share the rest of the mass equally.\n",
    "\n",
    "    Args:\n",
    "        alphabet_size: Number of symbols; must be at least 2.\n",
    "        first_char_prob: Probability of the first symbol, in ``[0, 1]``.\n",
    "            ``None`` selects the uniform value ``1 / alphabet_size``.\n",
    "\n",
    "    Returns:\n",
    "        The entropy computed with the natural logarithm.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``alphabet_size < 2`` or ``first_char_prob`` lies\n",
    "            outside ``[0, 1]``.\n",
    "    \"\"\"\n",
    "    if alphabet_size < 2:\n",
    "        raise ValueError(f\"alphabet_size must be >= 2, got {alphabet_size}\")\n",
    "    if first_char_prob is None:\n",
    "        first_char_prob = 1 / alphabet_size\n",
    "    if not 0 <= first_char_prob <= 1:\n",
    "        raise ValueError(f\"first_char_prob must be in [0, 1], got {first_char_prob}\")\n",
    "    remaining_char_prob = (1 - first_char_prob) / (alphabet_size - 1)\n",
    "\n",
    "    def _plogp(prob):\n",
    "        # Use the limit p * log(p) -> 0 as p -> 0; evaluating it directly\n",
    "        # would give 0 * -inf = nan.\n",
    "        return prob * np.log(prob) if prob > 0 else 0.0\n",
    "\n",
    "    return -(\n",
    "        _plogp(first_char_prob)\n",
    "        + (alphabet_size - 1) * _plogp(remaining_char_prob)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "589a78cf-456d-43ea-88fe-88d9d5107631",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bisection_method(error_func, a, b, tol=1e-6, max_iter=1000):\n",
    "    \"\"\"Find a root of ``error_func`` between ``a`` and ``b`` by bisection.\n",
    "\n",
    "    Args:\n",
    "        error_func: Scalar function whose root is sought. Its values at\n",
    "            ``a`` and ``b`` must have opposite signs, or one must be zero.\n",
    "        a: Lower end of the search interval.\n",
    "        b: Upper end of the search interval.\n",
    "        tol: Stop once half the interval width drops below this value.\n",
    "        max_iter: Upper bound on the number of bisection steps.\n",
    "\n",
    "    Returns:\n",
    "        An endpoint or midpoint where ``error_func`` is exactly zero,\n",
    "        otherwise the midpoint of the final interval.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``error_func(a)`` and ``error_func(b)`` have the\n",
    "            same sign, i.e. the interval does not bracket a root.\n",
    "    \"\"\"\n",
    "    err_a = error_func(a)\n",
    "    err_b = error_func(b)\n",
    "    if err_a == 0:\n",
    "        return a\n",
    "    if err_b == 0:\n",
    "        return b\n",
    "    if (err_a < 0) == (err_b < 0):\n",
    "        raise ValueError(\"error_func(a) and error_func(b) must have opposite signs\")\n",
    "\n",
    "    for _ in range(max_iter):\n",
    "        c = (a + b) / 2.0\n",
    "        err_c = error_func(c)  # evaluated once per step and reused below\n",
    "        if err_c == 0 or (b - a) / 2 < tol:\n",
    "            return c\n",
    "        # Keep the half whose endpoints still differ in sign. Comparing\n",
    "        # signs avoids the underflow a product of tiny values can hit.\n",
    "        if (err_c < 0) != (err_a < 0):\n",
    "            b = c\n",
    "        else:\n",
    "            a, err_a = c, err_c\n",
    "    return (a + b) / 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514d941b-85e2-402a-9ecb-7918a899dc8a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a715f53b-c054-4714-8202-14fbf2cbac2d",
   "metadata": {},
   "source": [
    "### Default entropies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8e1e522-ac25-4eee-bf93-b1f3ad53fd1e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Entropy with the default first-character probability (1 / alphabet_size).\n",
    "uniform_sizes = (2, 4, 7, 13, 26)\n",
    "for alphabet_size in uniform_sizes:\n",
    "    uniform_entropy = entropy(alphabet_size)\n",
    "    print(f\"Alphabet size {alphabet_size} entropy:\", uniform_entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6ab234-58a5-4b34-a9b1-006a49f5e14e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# For each smaller alphabet, search for the first-character probability at\n",
    "# which the 26-symbol entropy equals that alphabet's default entropy.\n",
    "for alphabet_size in (2, 4, 7, 13):\n",
    "    target_entropy = entropy(alphabet_size)\n",
    "    print(f\"Alphabet {alphabet_size} target entropy:\", target_entropy)\n",
    "\n",
    "    def error_func(p):\n",
    "        # Signed gap between the target and the entropy obtained with p.\n",
    "        return target_entropy - entropy(26, p)\n",
    "\n",
    "    # Search bracket, kept strictly inside (0, 1) so log(0) never occurs.\n",
    "    a = 0 + 1e-6\n",
    "    b = 1 - 1e-6\n",
    "    p_solution = bisection_method(error_func, a, b)\n",
    "    print(\"solution prob:\", p_solution)\n",
    "\n",
    "    solution_entropy = entropy(26, p_solution)\n",
    "    print(\"Entropy with solution prob:\", solution_entropy)\n",
    "    entropy_gap = target_entropy - solution_entropy\n",
    "    print(\"Entropy difference:\", entropy_gap)\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7951ae6c-008c-4d40-855a-4ff111437350",
   "metadata": {},
   "source": [
    "## Results test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5036bd53-1320-447f-9337-b0a6f3ba7987",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the runs first, keyed by the label that is looked up further down,\n",
    "# then attach the storage directory in a separate step.\n",
    "runs_by_label = {\n",
    "    f\"{num_tokens} Tokens\": res_util.load(f\"pyt-70m_a-26_t-{num_tokens}_p-8\", list(range(1)))\n",
    "    for num_tokens in [1024]\n",
    "}\n",
    "results = res_util.set_base_storage_dir(runs_by_label, RANDOM_STRING_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95e7a80e-4188-4e3a-b455-292176afbbc6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "results[\"1024 Tokens\"][0].value.memorization_log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41d46239-8f3b-4e20-be7f-1d83202dbe17",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "figures = res_util.show_results_overview(\n",
    "    results,\n",
    "    \"String Length\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e255ec89-45b6-4780-bba6-a4edd7f37631",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
