{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer\n",
    "from fractions import Fraction\n",
    "import random\n",
    "\n",
    "\n",
    "def randomly_remove_system_prompt(\n",
    "    text: str, freq: float, system_prompt: str | None = None\n",
    ") -> str:\n",
    "    \"\"\"Return `text` with every occurrence of `system_prompt` deleted, with probability `freq`.\n",
    "\n",
    "    When `system_prompt` is None or empty, `text` is returned untouched and no random\n",
    "    number is drawn; otherwise a single `random.random()` draw decides whether to strip it.\n",
    "    \"\"\"\n",
    "    if not system_prompt:\n",
    "        return text\n",
    "    should_remove = random.random() < freq\n",
    "    return text.replace(system_prompt, \"\") if should_remove else text\n",
    "\n",
    "\n",
    "def hf_mixed_dataset_to_generator(\n",
    "    tokenizer: AutoTokenizer,\n",
    "    pretrain_dataset: str = \"HuggingFaceFW/fineweb\",\n",
    "    chat_dataset: str = \"lmsys/lmsys-chat-1m\",\n",
    "    min_chars: int = 1,\n",
    "    main_frac: float = 0.9,  # 0.9 → 90 % main, 10 % aux\n",
    "    split: str = \"train\",\n",
    "    streaming: bool = True,\n",
    "    pretrain_key: str = \"text\",\n",
    "    chat_key: str = \"conversation\",\n",
    "    sequence_pack_pretrain: bool = True,\n",
    "    sequence_pack_chat: bool = False,\n",
    "    system_prompt_to_remove: str | None = None,\n",
    "    system_prompt_removal_freq: float = 0.9,\n",
    "):\n",
    "    \"\"\"Get a mix of pretrain and chat data at a specified ratio. By default, 90% of the data will be pretrain and 10% will be chat.\n",
    "\n",
    "    Default datasets:\n",
    "    pretrain_dataset: \"HuggingFaceFW/fineweb\"\n",
    "    chat_dataset: \"lmsys/lmsys-chat-1m\"\n",
    "\n",
    "    Note that you will have to request permission for lmsys (instant approval on HuggingFace).\n",
    "\n",
    "    min_chars: minimum number of characters per packed sample. To perform sequence packing, set it to ~4x sequence length in tokens.\n",
    "    Packed pretrain samples are joined with the eos token.\n",
    "    If it's low (like 1), each sample will just be a single row from the dataset, to be padded to the max length downstream. Sometimes this will fill the context, sometimes it won't.\n",
    "\n",
    "    Why use strings instead of tokens? Because dictionary learning expects an iterator of strings, and this is simple and good enough.\n",
    "\n",
    "    Implicit assumption: each sample will be truncated to sequence length when tokenized.\n",
    "\n",
    "    By default, we sequence pack the pretrain data and DO NOT sequence pack the chat data, as it would look kind of weird. The EOS token is used to separate\n",
    "    user / assistant messages, not to separate conversations from different users.\n",
    "    If you want to sequence pack the chat data, set sequence_pack_chat to True.\n",
    "\n",
    "    Pretrain format (packed) will be: <bos>text<eos>text<eos>...<eos>text (no trailing eos).\n",
    "    Pretrain format (unpacked) is the raw text of a single row.\n",
    "    Chat format will be <formatted chat message>. Packed: <formatted chat message><formatted chat message>...\n",
    "\n",
    "    Other parameters:\n",
    "    - system_prompt_to_remove: an optional string that will be removed from the chat data with a given frequency,\n",
    "        in both packed and unpacked chat samples. You probably want to verify that the system prompt you pass in is correct.\n",
    "    - system_prompt_removal_freq: the frequency with which the system prompt will be removed\n",
    "\n",
    "    Why? Well, we probably don't want to have 1000's of copies of the system prompt in the training dataset. But we also may not want to remove it entirely.\n",
    "    And we may want to use the LLM with no system prompt when comparing between models.\n",
    "    IDK, this is a complicated and annoying detail. At least this constrains the complexity to the dataset generator.\n",
    "\n",
    "    Returns a generator of strings that keeps alternating between the two datasets until either one runs\n",
    "    out of rows, then stops cleanly (instead of raising RuntimeError, see PEP 479). A partially packed\n",
    "    sample is discarded at that point.\n",
    "    \"\"\"\n",
    "    if not 0 < main_frac < 1:\n",
    "        raise ValueError(\"main_frac must be between 0 and 1 (exclusive)\")\n",
    "\n",
    "    assert min_chars > 0\n",
    "\n",
    "    # Load both datasets as iterable streams\n",
    "    pretrain_ds = iter(load_dataset(pretrain_dataset, split=split, streaming=streaming))\n",
    "    chat_ds = iter(load_dataset(chat_dataset, split=split, streaming=streaming))\n",
    "\n",
    "    # Convert the fraction to two small integers (e.g. 0.9 → 9 / 10)\n",
    "    frac = Fraction(main_frac).limit_denominator()\n",
    "    n_pretrain = frac.numerator\n",
    "    n_chat = frac.denominator - n_pretrain\n",
    "\n",
    "    eos_token = tokenizer.eos_token\n",
    "    # Fall back to eos when the tokenizer defines no bos token.\n",
    "    bos_token = tokenizer.bos_token if tokenizer.bos_token else eos_token\n",
    "\n",
    "    def next_pretrain() -> str:\n",
    "        # One raw pretrain row.\n",
    "        return next(pretrain_ds)[pretrain_key]\n",
    "\n",
    "    def next_chat() -> str:\n",
    "        # One chat row rendered with the tokenizer's chat template, with the system prompt\n",
    "        # randomly stripped. Shared by the packed and unpacked paths so both honour\n",
    "        # system_prompt_to_remove.\n",
    "        text = tokenizer.apply_chat_template(next(chat_ds)[chat_key], tokenize=False)\n",
    "        return randomly_remove_system_prompt(\n",
    "            text, system_prompt_removal_freq, system_prompt_to_remove\n",
    "        )\n",
    "\n",
    "    def pack(next_sample, prefix: str, sep: str) -> str:\n",
    "        # Concatenate samples until at least min_chars characters have been collected.\n",
    "        samples = []\n",
    "        length = 0\n",
    "        while length < min_chars:\n",
    "            sample = next_sample()\n",
    "            samples.append(sample)\n",
    "            length += len(sample)\n",
    "        return prefix + sep.join(samples)\n",
    "\n",
    "    def gen():\n",
    "        try:\n",
    "            while True:\n",
    "                for _ in range(n_pretrain):\n",
    "                    if sequence_pack_pretrain:\n",
    "                        yield pack(next_pretrain, bos_token, eos_token)\n",
    "                    else:\n",
    "                        yield next_pretrain()\n",
    "                for _ in range(n_chat):\n",
    "                    if sequence_pack_chat:\n",
    "                        # Rendered chats are concatenated with no extra prefix or separator.\n",
    "                        yield pack(next_chat, \"\", \"\")\n",
    "                    else:\n",
    "                        yield next_chat()\n",
    "        except StopIteration:\n",
    "            # A dataset ran out. Inside a generator an escaping StopIteration would be\n",
    "            # turned into RuntimeError (PEP 479), so end the stream explicitly instead.\n",
    "            return\n",
    "\n",
    "    return gen()\n",
    "\n",
    "\n",
    "main_dataset = \"HuggingFaceFW/fineweb\"\n",
    "aux_dataset = \"lmsys/lmsys-chat-1m\"\n",
    "# Alternative: model_name = \"google/gemma-2-2b-it\"\n",
    "model_name = \"Qwen/Qwen2.5-Coder-32B-Instruct\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Exact (Qwen-specific) system-prompt text to strip from formatted chats. It must match what\n",
    "# this model's chat template emits — verify by printing a formatted sample.\n",
    "system_prompt_to_remove = \"\"\"<|im_start|>system\n",
    "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n",
    "\"\"\"\n",
    "\n",
    "generator = hf_mixed_dataset_to_generator(\n",
    "    tokenizer,\n",
    "    pretrain_dataset=main_dataset,\n",
    "    chat_dataset=aux_dataset,\n",
    "    min_chars=6000,\n",
    "    system_prompt_to_remove=system_prompt_to_remove,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one mixed sample: its character length, then the text itself.\n",
    "v1 = next(generator)\n",
    "print(len(v1), v1, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch exploration of the raw chat dataset, using the generator's default split/streaming.\n",
    "split = \"train\"\n",
    "streaming = True\n",
    "main_key = \"text\"\n",
    "aux_key = \"conversation\"  # lmsys-chat-1m rows keep the chat under \"conversation\" (matches chat_key)\n",
    "\n",
    "chat_ds = iter(load_dataset(aux_dataset, split=split, streaming=streaming))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one raw row to see the dataset schema.\n",
    "row = next(chat_ds)\n",
    "print(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AutoTokenizer is already imported in the first cell.\n",
    "# NOTE: this rebinds the global `tokenizer` to Gemma; re-run the next cell before using\n",
    "# `system_prompt_to_remove`, which is the Qwen system prompt.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b-it\")\n",
    "\n",
    "v1 = next(chat_ds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the tokenizer for `model_name` (set in the first cell); AutoTokenizer is imported there too.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "print(tokenizer.eos_token)\n",
    "print(tokenizer.bos_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the same conversation both as text and as token ids, then show the token count,\n",
    "# the ids, and the text.\n",
    "conversation = v1[\"conversation\"]\n",
    "text = tokenizer.apply_chat_template(conversation, tokenize=False)\n",
    "tokens = tokenizer.apply_chat_template(conversation, tokenize=True)\n",
    "\n",
    "for item in (len(tokens), tokens, text):\n",
    "    print(item)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check system-prompt removal: for each chat print its length in characters (before\n",
    "# removal) and its first 100 characters after removal.\n",
    "removal_freq = 0.9\n",
    "for _ in range(10):\n",
    "    v1 = next(chat_ds)\n",
    "    chat_text = tokenizer.apply_chat_template(v1[\"conversation\"], tokenize=False)\n",
    "    print(len(chat_text))\n",
    "    chat_text = randomly_remove_system_prompt(chat_text, removal_freq, system_prompt_to_remove)\n",
    "    print(chat_text[:100])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
