{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from lib_project.notebook import setup_notebook\n",
    "\n",
    "# NOTE(review): presumably anchors paths/imports at the repository root, three directories up — confirm in setup_notebook.\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "# Reload edited library modules automatically so code changes apply without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import display, Markdown as md\n",
    "import torch\n",
    "\n",
    "from lib_llm.models.load import (\n",
    "    ModelConfig,\n",
    "    load_model_tokenizer,\n",
    ")\n",
    "from data.synthetic_strings.random import (\n",
    "    RandomStringConfig,\n",
    "    generate_random_strings,\n",
    ")\n",
    "from lib_llm.eval.memorization.prefix_mappings.eval import (\n",
    "    PrefixEvalConfig,\n",
    "    eval_prefix_lengths,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b9aeeb-86f0-4aa0-93cb-24b6300c7407",
   "metadata": {},
   "source": [
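    "# Testing the Prefix Mapping Computation Code\n",
    "\n",
    "Smoke test for `eval_prefix_lengths`: load a small Pythia model, generate a random-string dataset, and evaluate several prefix lengths using a dummy replacement function that returns all-zero replacements."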
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "444ebf84-4c54-4dde-b4da-9a2a6bc0eade",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A small Pythia checkpoint keeps this smoke test cheap to run.\n",
    "# NOTE(review): tokenizer_type presumably has to match the model family loaded here.\n",
    "tokenizer_type = \"pythia\"\n",
    "model_config = ModelConfig(\n",
    "    model_id=\"EleutherAI/pythia-70m\",\n",
    "    # base_dir=\"/home/exp/base_models\",  # optional, presumably a local checkpoint directory\n",
    ")\n",
    "# Optional extra argument to load_model_tokenizer: max_length=2048\n",
    "model, tokenizer = load_model_tokenizer(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0aad040-4b5f-4925-9c50-caf8538790bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a tiny synthetic dataset of random strings for the smoke test.\n",
    "conf = RandomStringConfig(\n",
    "    # Size of the generated data.\n",
    "    num_tokens=16,\n",
    "    alphabet_size=7,\n",
    "    num_partitions=1,\n",
    "    # Same tokenizer family as the model loaded above.\n",
    "    tokenizer_type=tokenizer_type,\n",
    "    # Fixed seed id (presumably so every run generates the same strings).\n",
    "    seed_id=1,\n",
    ")\n",
    "data = generate_random_strings(conf, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3343141a-bbf9-4ea4-ac5f-c7f54715c423",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the token ids of the first sample (presumably row 0 of the batch encoding).\n",
    "data.batch_encoding().input_ids[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a0c87c-70a0-4ec1-82b6-c472cecdf857",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples_per_prefix = 2\n",
    "prefix_config = PrefixEvalConfig(\n",
    "    seed=322,\n",
    "    prefix_lengths=[2, 4, 8],\n",
    "    num_samples_per_prefix=num_samples_per_prefix,\n",
    "    max_token_samples=2,\n",
    ")\n",
    "# Alternative hand-written input; it must be 1-D to satisfy the assert in get_replacements:\n",
    "# token_ids = torch.tensor([1, 2, 3, 4, 5])\n",
    "token_ids = data.batch_encoding().input_ids[0]\n",
    "\n",
    "\n",
    "def get_replacements(token_ids: torch.Tensor, target_length: int) -> torch.Tensor:\n",
    "    \"\"\"Dummy replacement function used to smoke-test `eval_prefix_lengths`.\n",
    "\n",
    "    Args:\n",
    "        token_ids: 1-D tensor of token ids for a single sequence.\n",
    "        target_length: Length of each returned row.\n",
    "\n",
    "    Returns:\n",
    "        An all-zero tensor of shape `(num_samples_per_prefix, target_length)`\n",
    "        with the same dtype and device as `token_ids`.\n",
    "    \"\"\"\n",
    "    assert token_ids.ndim == 1\n",
    "    # torch.zeros defaults to float32; keep the dtype (integer for token ids) and\n",
    "    # device of the input so the result is usable as token ids, e.g. for embedding lookups.\n",
    "    return torch.zeros(\n",
    "        (num_samples_per_prefix, target_length),\n",
    "        dtype=token_ids.dtype,\n",
    "        device=token_ids.device,\n",
    "    )\n",
    "\n",
    "\n",
    "res = eval_prefix_lengths(\n",
    "    prefix_config,\n",
    "    model,\n",
    "    tokenizer,\n",
    "    token_ids,\n",
    "    get_replacements,\n",
    ")\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55279102-19c2-41eb-9523-d39be1aa0a3d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
