{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e355b99d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Notebook bootstrap: hand the project helper the path three levels up\n",
    "# (presumably the repository root -- TODO confirm what setup_notebook configures).\n",
    "from lib_project.notebook import setup_notebook\n",
    "\n",
    "setup_notebook(\"../../../\")\n",
    "\n",
    "# Re-import edited project modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbd71c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from itertools import islice\n",
    "from pathlib import Path\n",
    "\n",
    "from IPython.display import display, Markdown as md\n",
    "import pandas as pd\n",
    "from transformers import DataCollatorForLanguageModeling\n",
    "import torch\n",
    "\n",
    "from lib_llm.models.load import (\n",
    "    ModelConfig,\n",
    "    load_tokenizer,\n",
    ")\n",
    "from data.synthetic_strings.random import (\n",
    "    RandomStringConfig,\n",
    "    generate_random_strings,\n",
    ")\n",
    "from data.synthetic_strings.deterministic_rules import (\n",
    "    DeterministicRuleStringConfig,\n",
    "    generate_deterministic_rule_strings,\n",
    ")\n",
    "from data.text_generation import (\n",
    "    TextGenerationDataConfig,\n",
    "    load_text_generation_data,\n",
    "    add_token_string_reps,\n",
    ")\n",
    "\n",
    "from experiments.practical_memorization_dynamics.experiment import (\n",
    "    ContextDataConfig,\n",
    "    create_injected_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b92cc8f4-3bc0-4cd0-bb7d-8bcfce60739c",
   "metadata": {},
   "source": [
    "## Loading natural data (WikiText-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58bd6282-8f1c-4da2-8b38-0fc209a800fa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenizer for the natural-data section. To inspect Llama-2 instead,\n",
    "# swap in the commented block below (it points at a local base_dir).\n",
    "tokenizer_type = \"pythia\"\n",
    "pythia_config = ModelConfig(\n",
    "    model_id=\"EleutherAI/pythia-70m\",\n",
    "    # base_dir=\"/home/exp/base_models\",\n",
    ")\n",
    "tokenizer = load_tokenizer(pythia_config, max_length=2048)\n",
    "\n",
    "# tokenizer_type = \"llama2\"\n",
    "# llama_config = ModelConfig(\n",
    "#     model_id=\"Llama-2-7b-hf\",\n",
    "#     base_dir=\"/home/exp/base_models\",\n",
    "# )\n",
    "# tokenizer = load_tokenizer(llama_config, max_length=2048)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "965740cf-61fa-4584-81de-af7e2ad2c80f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load 10 WikiText-2 sequences with a fixed seed\n",
    "# (sequence_length=1024, presumably measured in tokens -- TODO confirm).\n",
    "config = TextGenerationDataConfig(\n",
    "    dataset=\"wikitext\",\n",
    "    variant=\"wikitext-2-raw-v1\",\n",
    "    num_sequences=10,\n",
    "    sequence_length=1024,\n",
    "    seed=4932,\n",
    ")\n",
    "data = load_text_generation_data(config, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "036366f3-0b33-4336-a0d4-44fe4b32d0ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach human-readable token strings (presumably the `tokens` field read in the next cell).\n",
    "data_str = add_token_string_reps(data, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e28e0f45-dfd1-4964-b661-bbefcf633e08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check on the first 5 training sequences: token count vs. attention-mask sum.\n",
    "preview_examples = data_str[\"train\"].select(list(range(5)))\n",
    "for example in preview_examples:\n",
    "    print(\"num tokens\", len(example[\"tokens\"]))\n",
    "    print(\"attention sum:\", sum(example[\"attention_mask\"]))\n",
    "\n",
    "# Text of the test split.\n",
    "data_str[\"test\"][\"text\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62bc1575-98b8-4b1c-a285-48945fa0e59c",
   "metadata": {},
   "source": [
    "## Injecting random data into natural data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574ca7d8-1f55-45af-b7a2-ead7ebb7d177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the Phi-2 tokenizer for the injection cells below.\n",
    "tokenizer_type = \"phi\"\n",
    "phi_config = ModelConfig(model_id=\"microsoft/phi-2\")\n",
    "tokenizer = load_tokenizer(phi_config, max_length=2048)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d9e3e60-dd13-4d25-98d2-88b2b1bfb82a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Random strings (16 tokens, alphabet_size=7) combined with WikiText-2 context\n",
    "# sequences by create_injected_dataset.\n",
    "random_conf = RandomStringConfig(\n",
    "    seed_id=1,\n",
    "    num_tokens=16,\n",
    "    tokenizer_type=tokenizer_type,\n",
    "    num_partitions=1,\n",
    "    alphabet_size=7,\n",
    ")\n",
    "context_conf = ContextDataConfig(\n",
    "    dataset=\"wikitext\",\n",
    "    dataset_variant=\"wikitext-2-raw-v1\",\n",
    "    seed=4130,\n",
    "    sequence_length=24,\n",
    "    batch_size=4,\n",
    ")\n",
    "injected_data = create_injected_dataset(\n",
    "    random_conf,\n",
    "    context_conf,\n",
    "    2,  # NOTE(review): meaning of this positional argument is not visible here; pass it by keyword\n",
    "    tokenizer,\n",
    ")\n",
    "# To inspect only the random strings: generate_random_strings(random_conf, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d0988c0-4548-4c8d-9997-30a58ed5cbbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overview of the injected dataset (its \"train\" and \"test\" splits are used below).\n",
    "injected_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8033f6ae-965e-49d0-94a2-7d7745a2ef13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second dimension of the test split's input_ids (presumably the sequence length).\n",
    "test_input_ids = injected_data[\"test\"][\"input_ids\"]\n",
    "test_input_ids.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34500e92-b453-4203-aed1-a24b115045bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same token-string annotation as for the natural data, applied to the injected dataset.\n",
    "injected_data_str = add_token_string_reps(injected_data, tokenizer)\n",
    "injected_data_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f2327bc-725a-42da-ae2e-2cdcfddfef7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token strings and text of the first 16 training examples, separated by blank lines.\n",
    "for example in injected_data_str[\"train\"].select(list(range(16))):\n",
    "    print(example[\"tokens\"])\n",
    "    print(example[\"text\"], end=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "691d860b-0e03-4287-848f-165629ca7653",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview how the collator (mlm=False, i.e. causal-LM mode) batches the injected training data.\n",
    "PREVIEW_BATCH_SIZE = 4\n",
    "NUM_PREVIEW_BATCHES = 4\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer,\n",
    "    mlm=False,\n",
    "    return_tensors=\"pt\",\n",
    "    pad_to_multiple_of=8,\n",
    ")\n",
    "data_loader = torch.utils.data.DataLoader(\n",
    "    injected_data[\"train\"], batch_size=PREVIEW_BATCH_SIZE, collate_fn=data_collator\n",
    ")\n",
    "\n",
    "# islice stops after NUM_PREVIEW_BATCHES batches without fetching an extra one,\n",
    "# replacing the manual counter + break.\n",
    "for batch in islice(data_loader, NUM_PREVIEW_BATCHES):\n",
    "    print(\"input ids:\", batch[\"input_ids\"])\n",
    "    for ids in batch[\"input_ids\"]:\n",
    "        print(\"tokens:\", tokenizer.convert_ids_to_tokens(ids))\n",
    "    print()"
   ]
  },
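  {
   "cell_type": "markdown",
   "id": "a246a008-48c1-4c02-8d05-92e1afe4cc3e",
   "metadata": {},
   "source": [
    "## Tokenization checks (Llama-2, GPT-2, Pythia)\n",
    "\n",
    "The Llama-2 loader is commented out (it points at a local `base_dir`). The cells below load the GPT-2 tokenizer, then the Pythia tokenizer with a `<|startoftext|>` BOS token. They then inspect how special-token strings are tokenized and generate random strings with that tokenizer."
   ]
  },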
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11807f6-7db6-401d-9539-e18865b9ddb7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Llama-2 alternative (points at a local base_dir):\n",
    "# tokenizer_type = \"llama2\"\n",
    "# llama_config = ModelConfig(\n",
    "#     model_id=\"Llama-2-7b-hf\",\n",
    "#     base_dir=\"/home/exp/base_models\",\n",
    "# )\n",
    "# tokenizer = load_tokenizer(llama_config, max_length=2048)\n",
    "\n",
    "tokenizer_type = \"gpt2\"\n",
    "gpt2_config = ModelConfig(model_id=\"gpt2\", base_dir=None)\n",
    "tokenizer = load_tokenizer(gpt2_config, max_length=2048)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd620d01-e08c-4e15-a898-69d90d4ece23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pythia tokenizer with an explicit <|startoftext|> BOS token; reassigns `tokenizer` from the GPT-2 cell above.\n",
    "tokenizer_type = \"pythia\"\n",
    "pythia_bos_config = ModelConfig(\n",
    "    model_id=\"EleutherAI/pythia-70m\",\n",
    "    base_dir=None,\n",
    "    bos_token=\"<|startoftext|>\",\n",
    ")\n",
    "tokenizer = load_tokenizer(pythia_bos_config, max_length=2048)\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d637922-80f4-45eb-886b-1b44c2ea55af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Probe how special-token strings embedded in raw text are tokenized.\n",
    "probe_texts = [\"<|endoftext|>a\", \"<|padding|>b\", \"<|padding|>c\"]\n",
    "encoding = tokenizer(probe_texts, add_special_tokens=True)\n",
    "encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c64a92a2-e02a-4c51-a3e3-93598cfc4337",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Token strings for the first probe text.\n",
    "first_ids = encoding[\"input_ids\"][0]\n",
    "tokenizer.convert_ids_to_tokens(first_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66065e6b-1aa2-4280-933c-141569b45717",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Generate random strings with the tokenizer loaded above.\n",
    "# NOTE(review): artifacts_dir is a machine-specific absolute path; make it configurable.\n",
    "conf = RandomStringConfig(\n",
    "    seed_id=1,\n",
    "    num_tokens=16,\n",
    "    tokenizer_type=tokenizer_type,\n",
    "    num_partitions=2,\n",
    "    alphabet_size=7,\n",
    "    artifacts_dir=\"/home/exp/artifacts\",\n",
    ")\n",
    "# Uses the imported generate_random_strings (sample_random_strings was never imported).\n",
    "# Stored under its own name so the WikiText `data` from earlier cells is not overwritten.\n",
    "random_strings = generate_random_strings(conf, tokenizer)\n",
    "random_strings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe247d3e-d731-47dc-b46f-dd32ab260e2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): assumes the generated result exposes a `tokens` attribute -- confirm.\n",
    "random_strings.tokens"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
