{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dfd0608d-577b-4a65-9a44-b6680ad6bd7b",
   "metadata": {},
   "source": [
    "# Generates Tokenized Dataset\n",
    "Tokenizes a subset of OpenWebText for downstream use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09b2311c-a801-4166-9b2a-6dcc8e7ca1aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import scipy\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from typing import List, Union\n",
    "import random\n",
    "import torch.nn.functional as F\n",
    "import wandb\n",
    "import yaml\n",
    "\n",
    "with open('global_config.yaml') as config_file:\n",
    "    config_dict = yaml.safe_load(config_file)\n",
    "\n",
    "CACHE_DIR = config_dict['CACHE_DIR']\n",
    "SEED = 42\n",
    "print(CACHE_DIR)\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65648474-ce6a-4aa4-b1b9-09eb25de3a9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "TRUNCATION_LENGTH = 512  # max number of tokens for each data sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9738acd-65e7-4b7d-ae42-f87e575be15d",
   "metadata": {},
   "outputs": [],
   "source": [
    "modelA_name = 'gpt2-small' #'gemma-2-2b' #'pythia-70m-deduped'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2753ced-e03b-41de-9f88-d8eadf8cf1af",
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'gemma' in modelA_name:\n",
    "    access_token = config_dict['hf_access_token']\n",
    "    tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', trust_remote_code=True, cache_dir=CACHE_DIR, token=access_token)\n",
    "    hf_model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b', cache_dir=CACHE_DIR, token=access_token, torch_dtype=torch.float16, device_map='cpu')\n",
    "    modelA = HookedTransformer.from_pretrained(model_name=modelA_name, hf_model=hf_model, tokenizer=tokenizer, device='cpu', cache_dir=CACHE_DIR)\n",
    "else:\n",
    "    modelA = HookedTransformer.from_pretrained(modelA_name, cache_dir=CACHE_DIR, device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa008230-cd50-49d5-8c84-dec94e7d58c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Openwebtext\n",
    "dataset = iter(load_dataset('Skylion007/openwebtext', split='train', streaming=True))\n",
    "tokenized_dataset = {\n",
    "    'train': [],\n",
    "    'test': []\n",
    "}\n",
    "TOT_SAMPLES = 200000  # total samples\n",
    "TRAIN_SPLIT = 0.9  # percent of samples in training set\n",
    "\n",
    "for total_samples, sample in tqdm(enumerate(dataset)):\n",
    "    # Determine which set we are in\n",
    "    if total_samples <= TRAIN_SPLIT * TOT_SAMPLES:\n",
    "        dataset_key = 'train'\n",
    "    else:\n",
    "        dataset_key = 'test'\n",
    "    sample = sample['text']\n",
    "    # Tokenize, prepending bos\n",
    "    train_sample_tokenized = modelA.to_tokens(sample, prepend_bos=True).squeeze(0).to('cpu')\n",
    "    train_sample_length = len(train_sample_tokenized)\n",
    "    left_idx = 0\n",
    "    if left_idx + TRUNCATION_LENGTH <= train_sample_length:\n",
    "        # We can add the entire tokenized string\n",
    "        tokenized_dataset[dataset_key].append(train_sample_tokenized[left_idx:left_idx+TRUNCATION_LENGTH])\n",
    "    else:\n",
    "        # Pad using pad token\n",
    "        unpadded = train_sample_tokenized[left_idx:]\n",
    "        tokenized_dataset[dataset_key].append(torch.concatenate([\n",
    "            unpadded,\n",
    "            torch.ones(TRUNCATION_LENGTH - len(unpadded)) * modelA.tokenizer.pad_token_id # padding token\n",
    "        ]))\n",
    "    if total_samples+1 >= TOT_SAMPLES:\n",
    "        break\n",
    "# Save datasets\n",
    "for dataset_key in tokenized_dataset.keys():\n",
    "    compiled_dataset = torch.vstack(tokenized_dataset[dataset_key]).to(torch.long)\n",
    "    print(f\"Saving {dataset_key} to data/{modelA_name}_tokenized_dataset_{TOT_SAMPLES}_{dataset_key}_{TRUNCATION_LENGTH}.pt\")\n",
    "    torch.save(compiled_dataset, f'data/{modelA_name}_tokenized_dataset_{TOT_SAMPLES}_{dataset_key}_{TRUNCATION_LENGTH}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dadb2798-6add-4cae-a8a4-ac00532c8701",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
