{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "setup_notebook(\"../../../\")\n",
    "               \n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from IPython.display import display, Markdown as md\n",
    "import pandas as pd\n",
    "\n",
    "from lib_llm.models.load import (\n",
    "    ModelConfig,\n",
    "    load_tokenizer,\n",
    ")\n",
    "from defs import RANDOM_STRING_MODEL_DIR\n",
    "from data.synthetic_strings.random import (\n",
    "    RandomStringConfig,\n",
    "    generate_random_strings,\n",
    ")\n",
    "from data.synthetic_strings.deterministic_rules import (\n",
    "    DeterministicRuleStringConfig,\n",
    "    generate_deterministic_rule_strings,\n",
    ")\n",
    "\n",
    "# from experiments.model_training import results as res_util"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b92cc8f4-3bc0-4cd0-bb7d-8bcfce60739c",
   "metadata": {},
   "source": [
    "## Generating random string data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58bd6282-8f1c-4da2-8b38-0fc209a800fa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenizer_type = \"pythia\"\n",
    "tokenizer = load_tokenizer(\n",
    "    ModelConfig(\n",
    "        model_id=\"EleutherAI/pythia-70m\",\n",
    "        # base_dir=\"/home/exp/base_models\",\n",
    "    ),\n",
    "    max_length=2048,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d9e3e60-dd13-4d25-98d2-88b2b1bfb82a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "conf = RandomStringConfig(\n",
    "    seed_id=1,\n",
    "    num_tokens=16,\n",
    "    tokenizer_type=tokenizer_type,\n",
    "    num_partitions=2,\n",
    "    alphabet_size=7,\n",
    ")\n",
    "generate_random_strings(conf, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3178d5bd-bc68-4137-983d-0f9689182eb2",
   "metadata": {},
   "source": [
    "## Generating deterministic rule-based string data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34cf9a91-940c-4817-b32d-408d180f9064",
   "metadata": {},
   "outputs": [],
   "source": [
    "conf = DeterministicRuleStringConfig(\n",
    "    seed_id=0,\n",
    "    string_length=16,\n",
    "    num_strings=2,\n",
    "    premise_length=2,\n",
    "    tokenizer_type=tokenizer_type,\n",
    "    alphabet_size=7,\n",
    ")\n",
    "generate_deterministic_rule_strings(conf, tokenizer).raw_token_ids"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d32c6649-c0af-4aff-b2b0-25dcf2cd90b1",
   "metadata": {},
   "source": [
    "## Checking memorization logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55279102-19c2-41eb-9523-d39be1aa0a3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_tokens = 128\n",
    "data_id = f\"sid-0_a-26_t-{num_tokens}_p-1\"\n",
    "log_path = f\"../artifacts/random_strings/models/pythia-1b/{data_id}/epoch_100/logs/memorization.parquet\"\n",
    "# log_path = f\"/ANONYMOUS/llm-1/work/ANONYMOUS/llms/artifacts/random_strings/models/pythia-1b/{data_id}/epoch_100/logs/memorization.parquet\"\n",
    "mem_log = pd.read_parquet(log_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71be227b-be47-4115-9590-7e45a277f209",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mem_log"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c763bd4-3ac7-4a35-97c4-c2eb864d4d47",
   "metadata": {},
   "source": [
    "### Loss log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f65c6f-6da2-4c26-959a-816d12fe4bff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "num_tokens = 128\n",
    "data_id = f\"sid-0_a-26_t-{num_tokens}_p-1\"\n",
    "log_path = f\"../artifacts/random_strings/models/pythia-1b/{data_id}/epoch_100/logs/loss.parquet\"\n",
    "# log_path = f\"/ANONYMOUS/llm-1/work/ANONYMOUS/llms/artifacts/random_strings/models/pythia-1b/{data_id}/epoch_100/logs/memorization.parquet\"\n",
    "pd.read_parquet(log_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb8b19f7-ca1a-47c8-ab63-0bcadf974ea8",
   "metadata": {},
   "source": [
    "### Test config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd2d1de-8fe5-4453-ac8d-3b41c7e02171",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_id = \"sid-0_a-7_t-64_p-1\"\n",
    "log_path = f\"../artifacts/random_strings/models/pythia-70m/{data_id}/epoch_1/logs/memorization.parquet\"\n",
    "pd.read_parquet(log_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60d2a19b-2281-42f2-9739-e626640a42db",
   "metadata": {},
   "source": [
    "## Compute first character probability for desired entropy level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6014f7-b0eb-4499-aab1-31deeb7cb07b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "814f1a25-f091-4a13-a948-73738ecf5470",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def entropy(alphabet_size: int, first_char_prob: float | None = None) -> float:\n",
    "    \"\"\"Shannon entropy (in nats) of a one-parameter skewed distribution.\n",
    "\n",
    "    The distribution is over `alphabet_size` symbols: the first symbol has\n",
    "    probability `first_char_prob` and the remaining probability mass is split\n",
    "    evenly across the other `alphabet_size - 1` symbols.\n",
    "\n",
    "    Args:\n",
    "        alphabet_size: Number of symbols; must be at least 2, otherwise the\n",
    "            per-symbol share of the remaining mass divides by zero.\n",
    "        first_char_prob: Probability of the first symbol, in [0, 1]. Defaults\n",
    "            to `1 / alphabet_size`, i.e. the uniform distribution.\n",
    "\n",
    "    Returns:\n",
    "        The entropy computed with the natural logarithm.\n",
    "    \"\"\"\n",
    "    if first_char_prob is None:\n",
    "        first_char_prob = 1 / alphabet_size\n",
    "    remaining_char_prob = (1 - first_char_prob) / (alphabet_size - 1)\n",
    "\n",
    "    def p_log_p(prob: float) -> float:\n",
    "        # Use the limit p * log(p) -> 0 as p -> 0; evaluating it directly gives\n",
    "        # 0 * -inf = nan, which made the endpoints p = 0 and p = 1 unusable.\n",
    "        return 0.0 if prob == 0 else prob * np.log(prob)\n",
    "\n",
    "    return -(\n",
    "        p_log_p(first_char_prob)\n",
    "        + (alphabet_size - 1) * p_log_p(remaining_char_prob)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "589a78cf-456d-43ea-88fe-88d9d5107631",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bisection_method(error_func, a, b, tol=1e-6, max_iter=1000):\n",
    "    \"\"\"Locate a root of `error_func` inside the interval [a, b] by bisection.\n",
    "\n",
    "    Args:\n",
    "        error_func: Function of one number whose sign change marks the root.\n",
    "        a: Lower end of the search interval.\n",
    "        b: Upper end of the search interval.\n",
    "        tol: Stop once half the interval width is below this value.\n",
    "        max_iter: Upper bound on the number of halving steps.\n",
    "\n",
    "    Returns:\n",
    "        The midpoint of the final interval, or an endpoint if `error_func` is\n",
    "        exactly zero there.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `error_func(a)` and `error_func(b)` have the same sign,\n",
    "            so the interval is not known to contain a root.\n",
    "    \"\"\"\n",
    "    f_a = error_func(a)\n",
    "    f_b = error_func(b)\n",
    "    # A root sitting exactly on an endpoint would otherwise be bisected away from.\n",
    "    if f_a == 0:\n",
    "        return a\n",
    "    if f_b == 0:\n",
    "        return b\n",
    "    if f_a * f_b > 0:\n",
    "        raise ValueError(\n",
    "            \"error_func must have opposite signs at a and b to bracket a root\"\n",
    "        )\n",
    "\n",
    "    for _ in range(max_iter):\n",
    "        c = (a + b) / 2.0\n",
    "        f_c = error_func(c)  # evaluated once per step and reused below\n",
    "        if f_c == 0 or (b - a) / 2 < tol:\n",
    "            return c\n",
    "        if f_c * f_a < 0:\n",
    "            b = c\n",
    "        else:\n",
    "            # Carry the cached value along with the moved endpoint.\n",
    "            a, f_a = c, f_c\n",
    "    return (a + b) / 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514d941b-85e2-402a-9ecb-7918a899dc8a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a715f53b-c054-4714-8202-14fbf2cbac2d",
   "metadata": {},
   "source": [
    "### Default entropies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8e1e522-ac25-4eee-bf93-b1f3ad53fd1e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for alphabet_size in [2, 4, 7, 13, 26]:\n",
    "    print(f\"Alphabet size {alphabet_size} entropy:\", entropy(alphabet_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6ab234-58a5-4b34-a9b1-006a49f5e14e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Size of the alphabet whose first-character probability is being solved for.\n",
    "skewed_alphabet_size = 26\n",
    "# Stay strictly inside (0, 1) so entropy() never takes the logarithm of zero.\n",
    "prob_eps = 1e-6\n",
    "\n",
    "for alphabet_size in [2, 4, 7, 13]:\n",
    "    # Entropy of the uniform distribution over the smaller alphabet is the target.\n",
    "    target_entropy = entropy(alphabet_size)\n",
    "    print(f\"Alphabet {alphabet_size} target entropy:\", target_entropy)\n",
    "\n",
    "    def error_func(p, target=target_entropy):\n",
    "        # Signed gap to the target; zero at the probability we are looking for.\n",
    "        # `target` is bound as a default so the function does not depend on the\n",
    "        # loop variable being reassigned on a later iteration.\n",
    "        return target - entropy(skewed_alphabet_size, p)\n",
    "\n",
    "    p_solution = bisection_method(error_func, prob_eps, 1 - prob_eps)\n",
    "    print(\"solution prob:\", p_solution)\n",
    "    solution_entropy = entropy(skewed_alphabet_size, p_solution)\n",
    "    print(\"Entropy with solution prob:\", solution_entropy)\n",
    "    print(\"Entropy difference:\", target_entropy - solution_entropy)\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7951ae6c-008c-4d40-855a-4ff111437350",
   "metadata": {},
   "source": [
    "## Results test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5036bd53-1320-447f-9337-b0a6f3ba7987",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# `res_util` is not imported by the import cell (the line there is commented\n",
    "# out), so import it here to keep this section runnable on a fresh kernel.\n",
    "# NOTE(review): confirm this module path is still valid.\n",
    "from experiments.model_training import results as res_util\n",
    "\n",
    "# Load the stored results for each string length, keyed by a display label.\n",
    "results = res_util.set_base_storage_dir({\n",
    "    f\"{num_tokens} Tokens\": res_util.load(f\"pyt-70m_a-26_t-{num_tokens}_p-8\", list(range(1)))\n",
    "    for num_tokens in [1024]\n",
    "}, RANDOM_STRING_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95e7a80e-4188-4e3a-b455-292176afbbc6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "results[\"1024 Tokens\"][0].value.memorization_log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41d46239-8f3b-4e20-be7f-1d83202dbe17",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "figures = res_util.show_results_overview(\n",
    "    results,\n",
    "    \"String Length\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a246a008-48c1-4c02-8d05-92e1afe4cc3e",
   "metadata": {},
   "source": [
    "## Llama-2 tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11807f6-7db6-401d-9539-e18865b9ddb7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# tokenizer_type = \"llama2\"\n",
    "# tokenizer = load_tokenizer(\n",
    "#     ModelConfig(\n",
    "#         model_id=\"Llama-2-7b-hf\",\n",
    "#         base_dir=\"/home/exp/base_models\",\n",
    "#     ),\n",
    "#     max_length=2048,\n",
    "# )\n",
    "\n",
    "tokenizer_type = \"gpt2\"\n",
    "tokenizer = load_tokenizer(\n",
    "    ModelConfig(\n",
    "        model_id=\"gpt2\",\n",
    "        base_dir=None,\n",
    "    ),\n",
    "    max_length=2048,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd620d01-e08c-4e15-a898-69d90d4ece23",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer_type = \"pythia\"\n",
    "tokenizer = load_tokenizer(\n",
    "    ModelConfig(\n",
    "        model_id=\"EleutherAI/pythia-70m\",\n",
    "        base_dir=None,\n",
    "        bos_token=\"<|startoftext|>\",\n",
    "    ),\n",
    "    max_length=2048,\n",
    ")\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d637922-80f4-45eb-886b-1b44c2ea55af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "encoding = tokenizer([\"<|endoftext|>a\", \"<|padding|>b\", \"<|padding|>c\"], add_special_tokens=True)\n",
    "encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c64a92a2-e02a-4c51-a3e3-93598cfc4337",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenizer.convert_ids_to_tokens(encoding.input_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66065e6b-1aa2-4280-933c-141569b45717",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "conf = RandomStringConfig(\n",
    "    seed_id=1,\n",
    "    num_tokens=16,\n",
    "    tokenizer_type=tokenizer_type,\n",
    "    num_partitions=2,\n",
    "    alphabet_size=7,\n",
    "    artifacts_dir=\"/home/exp/artifacts\",\n",
    ")\n",
    "# `sample_random_strings` is never defined or imported in this notebook;\n",
    "# `generate_random_strings` is the imported generator called with the same\n",
    "# (conf, tokenizer) arguments in the random-string section above.\n",
    "# NOTE(review): confirm the returned object still exposes `.tokens`, which the\n",
    "# next cell reads.\n",
    "data = generate_random_strings(conf, tokenizer)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe247d3e-d731-47dc-b46f-dd32ab260e2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data.tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4c2c288-6517-4bb5-a5de-286dd738b354",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
