{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/ai-safety-foundation/sparse_autoencoder/blob/main/docs/content/pre-process-datasets.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pre-process datasets\n",
    "\n",
    "When training a sparse autoencoder (SAE) often you want to use a text dataset such as [The\n",
    "Pile](https://huggingface.co/datasets/monology/pile-uncopyrighted). \n",
    "\n",
    "The `TextDataset` class can\n",
    "pre-process this for you on the fly (i.e. tokenize and split into `context_size` chunks of tokens),\n",
    "so that you can get started right away. However, if you're experimenting a lot, it can be nicer to\n",
    "run this once and then save the resulting dataset to HuggingFace. You can then use\n",
    "`PreTokenizedDataset` to load this directly, saving you from running this pre-processing every time\n",
    "you use it.\n",
    "\n",
    "The following code shows you how to do this, and is also used to upload a set of commonly used\n",
    "datasets for SAE training to [Alan Cooney's HuggingFace hub](https://huggingface.co/alancooney)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note you will also need to login to HuggingFace via the CLI:\n",
    "\n",
    "```shell\n",
    "huggingface-cli login\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if we're in Colab\n",
    "try:\n",
    "    import google.colab  # noqa: F401 # type: ignore\n",
    "\n",
    "    in_colab = True\n",
    "except ImportError:\n",
    "    in_colab = False\n",
    "\n",
    "#  Install if in Colab\n",
    "if in_colab:\n",
    "    %pip install sparse_autoencoder transformer_lens transformers wandb datasets\n",
    "\n",
    "# Otherwise enable hot reloading in dev mode\n",
    "if not in_colab:\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer\n",
    "from sparse_autoencoder import TextDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload helper\n",
    "\n",
    "Here we define a helper function to upload multiple datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class DatasetToPreprocess:\n",
    "    \"\"\"Dataset to preprocess info.\"\"\"\n",
    "\n",
    "    source_path: str\n",
    "    \"\"\"Source path from HF (e.g. `roneneldan/TinyStories`).\"\"\"\n",
    "\n",
    "    tokenizer_name: str\n",
    "    \"\"\"HF tokenizer name (e.g. `gpt2`).\"\"\"\n",
    "\n",
    "    data_dir: str | None = None\n",
    "    \"\"\"Data directory to download from the source dataset.\"\"\"\n",
    "\n",
    "    data_files: list[str] | None = None\n",
    "    \"\"\"Data files to download from the source dataset.\"\"\"\n",
    "\n",
    "    hugging_face_username: str = \"alancooney\"\n",
    "    \"\"\"HF username for the upload.\"\"\"\n",
    "\n",
    "    @property\n",
    "    def source_alias(self) -> str:\n",
    "        \"\"\"Create a source alias for the destination dataset name.\n",
    "\n",
    "        Returns:\n",
    "            The modified source path as source alias.\n",
    "        \"\"\"\n",
    "        return self.source_path.replace(\"/\", \"-\")\n",
    "\n",
    "    @property\n",
    "    def tokenizer_alias(self) -> str:\n",
    "        \"\"\"Create a tokenizer alias for the destination dataset name.\n",
    "\n",
    "        Returns:\n",
    "            The modified tokenizer name as tokenizer alias.\n",
    "        \"\"\"\n",
    "        return self.tokenizer_name.replace(\"/\", \"-\")\n",
    "\n",
    "    @property\n",
    "    def destination_repo_name(self) -> str:\n",
    "        \"\"\"Destination repo name.\n",
    "\n",
    "        Returns:\n",
    "            The destination repo name.\n",
    "        \"\"\"\n",
    "        return f\"sae-{self.source_alias}-tokenizer-{self.tokenizer_alias}\"\n",
    "\n",
    "    @property\n",
    "    def destination_repo_id(self) -> str:\n",
    "        \"\"\"Destination repo ID.\n",
    "\n",
    "        Returns:\n",
    "            The destination repo ID.\n",
    "        \"\"\"\n",
    "        return f\"{self.hugging_face_username}/{self.destination_repo_name}\"\n",
    "\n",
    "\n",
    "def upload_datasets(datasets_to_preprocess: list[DatasetToPreprocess]) -> None:\n",
    "    \"\"\"Upload datasets to HF.\n",
    "\n",
    "    Warning:\n",
    "        Assumes you have already created the corresponding repos on HF.\n",
    "\n",
    "    Args:\n",
    "        datasets_to_preprocess: List of datasets to preprocess.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the repo doesn't exist.\n",
    "    \"\"\"\n",
    "    repositories_updating = [dataset.destination_repo_id for dataset in datasets_to_preprocess]\n",
    "    print(\"Updating repositories:\\n\" \"\\n\".join(repositories_updating))\n",
    "\n",
    "    for dataset in datasets_to_preprocess:\n",
    "        print(\"Processing dataset: \", dataset.source_path)\n",
    "\n",
    "        # Preprocess\n",
    "        tokenizer = AutoTokenizer.from_pretrained(dataset.tokenizer_name)\n",
    "        text_dataset = TextDataset(\n",
    "            dataset_path=dataset.source_path,\n",
    "            tokenizer=tokenizer,\n",
    "            pre_download=True,  # Must be true to upload after pre-processing, to the hub.\n",
    "            dataset_files=dataset.data_files,\n",
    "            dataset_dir=dataset.data_dir,\n",
    "        )\n",
    "        print(\"Size: \", text_dataset.dataset.size_in_bytes)\n",
    "        print(\"Info: \", text_dataset.dataset.info)\n",
    "\n",
    "        # Upload\n",
    "        text_dataset.push_to_hugging_face_hub(repo_id=dataset.destination_repo_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload to Hugging Face"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets: list[DatasetToPreprocess] = [\n",
    "    DatasetToPreprocess(\n",
    "        source_path=\"roneneldan/TinyStories\",\n",
    "        tokenizer_name=\"gpt2\",\n",
    "        # Get the newer versions (Generated with GPT-4 only)\n",
    "        data_files=[\"TinyStoriesV2-GPT4-train.txt\", \"TinyStoriesV2-GPT4-valid.txt\"],\n",
    "    ),\n",
    "    DatasetToPreprocess(\n",
    "        source_path=\"monology/pile-uncopyrighted\",\n",
    "        tokenizer_name=\"gpt2\",\n",
    "        # Get just the first few (each file is 11GB so this should be enough for a large dataset)\n",
    "        data_files=[\n",
    "            \"00.jsonl.zst\",\n",
    "            \"01.jsonl.zst\",\n",
    "            \"02.jsonl.zst\",\n",
    "            \"03.jsonl.zst\",\n",
    "            \"04.jsonl.zst\",\n",
    "            \"05.jsonl.zst\",\n",
    "        ],\n",
    "        data_dir=\"train\",\n",
    "    ),\n",
    "    DatasetToPreprocess(\n",
    "        source_path=\"monology/pile-uncopyrighted\",\n",
    "        tokenizer_name=\"EleutherAI/gpt-neox-20b\",\n",
    "        data_files=[\n",
    "            \"00.jsonl.zst\",\n",
    "            \"01.jsonl.zst\",\n",
    "            \"02.jsonl.zst\",\n",
    "            \"03.jsonl.zst\",\n",
    "            \"04.jsonl.zst\",\n",
    "            \"05.jsonl.zst\",\n",
    "        ],\n",
    "        data_dir=\"train\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "upload_datasets(datasets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check a dataset is as expected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "downloaded_dataset = load_dataset(\n",
    "    \"alancooney/sae-roneneldan-TinyStories-tokenizer-gpt2\", streaming=True\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "\n",
    "i = 0\n",
    "first_k = 3\n",
    "for data_item in iter(downloaded_dataset[\"train\"]):  # type:ignore\n",
    "    # Get just the first few\n",
    "    i += 1\n",
    "    if i >= first_k:\n",
    "        break\n",
    "\n",
    "    # Print the decoded items\n",
    "    input_ids = data_item[\"input_ids\"]\n",
    "    decoded = tokenizer.decode(input_ids)\n",
    "    print(f\"{len(input_ids)} tokens: {decoded}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
