{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Training PEFT models with new tokens being added to the embedding layers and tokenizer\n",
    "\n",
    "In this example, we will learn how to train a LoRA model when adding new tokens to the tokenizer and model.\n",
    "This is a common use case when doing the following:\n",
    "\n",
    "1. Instruction finetuning with new tokens being added, such as `<|user|>`, `<|assistant|>`, `<|system|>`, `</s>`, `<s>`, to properly format the conversations.\n",
    "2. Finetuning on a specific language wherein language-specific tokens are added, e.g., Korean tokens being added to the vocabulary for finetuning an LLM on Korean datasets.\n",
    "3. Instruction finetuning to return outputs in a certain format to enable agent behaviour, with new tokens such as `<|FUNCTIONS|>`, `<|BROWSE|>`, `<|TEXT2IMAGE|>`, `<|ASR|>`, `<|TTS|>`, `<|GENERATECODE|>`, `<|RAG|>`.\n",
    "\n",
    "In such cases, you add the Embedding modules to the LoRA `target_modules`. PEFT will take care of saving the embedding layers with the newly added tokens along with the adapter weights that were trained on the specific initialization of the embedding weights of the added tokens."
   ],
   "id": "4c63c3e72509649a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Let's import the necessary libraries",
   "id": "494325354b962ae9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "os.environ[\"WANDB_PROJECT\"] = \"PeftExamples\"\n",
    "import transformers\n",
    "from peft import (\n",
    "    LoraConfig,\n",
    "    PeftConfig,\n",
    "    PeftModel,\n",
    "    get_peft_model,\n",
    "    prepare_model_for_kbit_training,\n",
    ")\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    HfArgumentParser,\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    "    default_data_collator,\n",
    ")\n",
    "import torch\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Optional\n",
    "from dataclass_csv import DataclassReader\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from enum import Enum"
   ],
   "id": "20e750516ba5a8a5"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Prepare Model and Tokenizer",
   "id": "aefb82766654064a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Now, we will add 27 new tokens and also replace the existing pad, bos and eos tokens of the model.",
   "id": "331d0c43e17d81c2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "class SpecialTokens(str, Enum):\n",
    "    \"\"\"Marker tokens that get added to the tokenizer vocabulary.\n",
    "\n",
    "    The class mixes in ``str``, so every member is also a plain string.\n",
    "    \"\"\"\n",
    "\n",
    "    begin_target = \"<|begintarget|>\"\n",
    "    # `end_target` is also handed to the tokenizer as its eos token.\n",
    "    end_target = \"<|endtarget|>\"\n",
    "    begin_context = \"<|begincontext|>\"\n",
    "    end_context = \"<|endcontext|>\"\n",
    "    system = \"<|system|>\"\n",
    "    user = \"<|user|>\"\n",
    "    begin_last_user_utterance = \"<|beginlastuserutterance|>\"\n",
    "    end_last_user_utterance = \"<|endlastuserutterance|>\"\n",
    "    begin_dsts = \"<|begindsts|>\"\n",
    "    end_dsts = \"<|enddsts|>\"\n",
    "    begin_dst = \"<|begindst|>\"\n",
    "    end_dst = \"<|enddst|>\"\n",
    "    begin_belief = \"<|beginbelief|>\"\n",
    "    end_belief = \"<|endbelief|>\"\n",
    "    begin_response = \"<|beginresponse|>\"\n",
    "    end_response = \"<|endresponse|>\"\n",
    "    begin_action = \"<|beginaction|>\"\n",
    "    end_action = \"<|endaction|>\"\n",
    "    begin_user_action = \"<|beginuseraction|>\"\n",
    "    end_user_action = \"<|enduseraction|>\"\n",
    "    sys_actions = \"<|sysactions|>\"\n",
    "    begin_intent = \"<|beginintent|>\"\n",
    "    end_intent = \"<|endintent|>\"\n",
    "    begin_requested_slots = \"<|beginrequestedslots|>\"\n",
    "    end_requested_slots = \"<|endrequestedslots|>\"\n",
    "    # The last two are handed to the tokenizer as its pad and bos tokens.\n",
    "    pad_token = \"<|pad|>\"\n",
    "    bos_token = \"<|startoftext|>\"\n",
    "\n",
    "    @classmethod\n",
    "    def list(cls):\n",
    "        \"\"\"Return the text of every token, in declaration order.\"\"\"\n",
    "        token_texts = []\n",
    "        for member in cls:\n",
    "            token_texts.append(member.value)\n",
    "        return token_texts"
   ],
   "id": "3e173550fc0f1f73"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We will be finetuning the Mistral-7B model. Let's load the tokenizer and add the special tokens, then load the base model and resize its embedding layers to accommodate the newly added tokens.",
   "id": "722d1fddce8713ee"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model_name = \"mistralai/Mistral-7B-v0.1\"\n",
    "# Load the tokenizer with the custom pad and bos tokens, `<|endtarget|>` as the eos token,\n",
    "# and every SpecialTokens value passed as an additional special token.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_name,\n",
    "    pad_token=SpecialTokens.pad_token.value,\n",
    "    bos_token=SpecialTokens.bos_token.value,\n",
    "    eos_token=SpecialTokens.end_target.value,\n",
    "    additional_special_tokens=SpecialTokens.list(),\n",
    ")\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    low_cpu_mem_usage=True\n",
    "    # attn_implementation =\"flash_attention_2\", # leading to an error\n",
    ")\n",
    "# Resize the embeddings to the tokenizer's new length to accommodate the added tokens.\n",
    "model.resize_token_embeddings(len(tokenizer))"
   ],
   "id": "8a3da5f5776fd7e7"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Apply LoRA",
   "id": "25de925f3629fa4b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Besides the attention projections, the embedding modules (`embed_tokens`, `lm_head`)\n",
    "# are listed in `target_modules` because new tokens were added to the vocabulary.\n",
    "config = LoraConfig(\n",
    "    r=64, lora_alpha=128, lora_dropout=0.0, target_modules=[\"embed_tokens\", \"lm_head\", \"q_proj\", \"v_proj\"]\n",
    ")\n",
    "model = get_peft_model(model, config)\n",
    "# print_trainable_parameters() prints the summary itself and returns None,\n",
    "# so wrapping it in print() would emit a stray \"None\" line.\n",
    "model.print_trainable_parameters()\n",
    "print(model)"
   ],
   "id": "854b6726bf462c33"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Prepare Dataset",
   "id": "2c5a907924157fdf"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"smangrul/assistant_chatbot_dataset\")\n",
    "# Fix the seed so the train/test split is the same on every Restart & Run All.\n",
    "dataset = dataset[\"train\"].train_test_split(test_size=0.2, seed=42)\n",
    "\n",
    "text_column = \"context\"\n",
    "label_column = \"target\"\n",
    "max_length = 512\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch into fixed-length causal-LM features.\n",
    "\n",
    "    Each row becomes ``context tokens + target tokens + eos``, left-padded with the\n",
    "    pad token up to ``max_length``. Labels are ``-100`` on the context and on the\n",
    "    padding, so only the target tokens and the eos token carry a label.\n",
    "\n",
    "    Rows longer than ``max_length`` get no padding and are cut on the right, which\n",
    "    can drop part of the target and the eos token.\n",
    "\n",
    "    Args:\n",
    "        examples: A batch holding the ``text_column`` and ``label_column`` columns.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output for the batch, where every row of ``input_ids``,\n",
    "        ``attention_mask`` and ``labels`` has exactly ``max_length`` entries.\n",
    "    \"\"\"\n",
    "    model_inputs = tokenizer(examples[text_column])\n",
    "    # don't add bos token to the targets because they are concatenated after the context\n",
    "    target_encodings = tokenizer([str(x) for x in examples[label_column]], add_special_tokens=False)\n",
    "    batch_labels = []\n",
    "    for i in range(len(examples[text_column])):\n",
    "        context_ids = model_inputs[\"input_ids\"][i]\n",
    "        target_ids = target_encodings[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        input_ids = context_ids + target_ids\n",
    "        label_ids = [-100] * len(context_ids) + target_ids\n",
    "        # Negative for rows longer than max_length; multiplying a list by a negative\n",
    "        # number yields an empty list, so such rows simply receive no padding.\n",
    "        pad_len = max_length - len(input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = ([tokenizer.pad_token_id] * pad_len + input_ids)[:max_length]\n",
    "        model_inputs[\"attention_mask\"][i] = ([0] * pad_len + [1] * len(input_ids))[:max_length]\n",
    "        batch_labels.append(([-100] * pad_len + label_ids)[:max_length])\n",
    "    model_inputs[\"labels\"] = batch_labels\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]"
   ],
   "id": "52f270803d84462b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "train_dataset",
   "id": "3548ff8213ed09dc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Stand-alone loader used by the next cell to look at one collated batch;\n",
    "# the Trainer further down is given `train_dataset` directly rather than this loader.\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=8, pin_memory=True\n",
    ")"
   ],
   "id": "fa8eac73dbff387e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "next(iter(train_dataloader))",
   "id": "761da6566065478f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "tokenizer.decode(train_dataset[0][\"input_ids\"])",
   "id": "3142e8b402ed74fe"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Train the model",
   "id": "ae761f25b4abd397"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Run configuration: 2 epochs in bf16 with gradient checkpointing, logging every 10 steps,\n",
    "# and pushing the result to a private Hub repository.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"mistral_lora_clm_with_added_tokens\",\n",
    "    num_train_epochs=2,\n",
    "    save_total_limit=5,\n",
    "    per_device_train_batch_size=8,\n",
    "    warmup_steps=10,\n",
    "    weight_decay=0.0001,\n",
    "    dataloader_drop_last=True,\n",
    "    bf16=True,\n",
    "    logging_steps=10,\n",
    "    learning_rate=1e-5,\n",
    "    gradient_checkpointing=True,\n",
    "    gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
    "    remove_unused_columns=False,\n",
    "    hub_model_id=\"smangrul/mistral_lora_clm_with_added_tokens\",\n",
    "    push_to_hub=True,\n",
    "    hub_private_repo=True,\n",
    ")\n",
    "# Only a training set is passed; the test split is used for a manual spot check later on.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    data_collator=default_data_collator,\n",
    ")\n",
    "# model.config.use_cache = False\n",
    "trainer.train()"
   ],
   "id": "ce748726f5fa7bf"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Check the model output on a sample from evaluation dataset",
   "id": "e6cae915a8695045"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import random\n",
    "\n",
    "# randrange() excludes the upper bound. randint(0, n) includes it, so it could\n",
    "# return n and make the lookup below raise IndexError.\n",
    "i = random.randrange(len(dataset[\"test\"]))\n",
    "context = dataset[\"test\"][i][\"context\"]\n",
    "\n",
    "batch = tokenizer(context, return_tensors=\"pt\")\n",
    "batch = {k: v.to(\"cuda\") for k, v in batch.items()}\n",
    "model.eval()\n",
    "output_tokens = model.generate(\n",
    "    **batch,\n",
    "    max_new_tokens=256,\n",
    "    do_sample=True,\n",
    "    temperature=0.2,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.pad_token_id,\n",
    ")\n",
    "# Keep only the text generated after the end-of-context marker.\n",
    "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n",
    "target = dataset[\"test\"][i][\"target\"]\n",
    "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")"
   ],
   "id": "39d40502c3e8452f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Save the Adapter model ",
   "id": "36d3f624325ebe62"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "When the lora layers are applied to embedding layers, the corresponding base model embedding layers are also saved. ",
   "id": "5686996a36992174"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Publish the results: first through the Trainer, then the PEFT model itself under the\n",
    "# name given by `output_dir`.\n",
    "# NOTE(review): both calls presumably need Hub write access - confirm credentials are set up.\n",
    "trainer.push_to_hub()\n",
    "trainer.model.push_to_hub(training_args.output_dir)"
   ],
   "id": "de622cf99df21405"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Check the model loading is working as expected and generating plausible outputs.",
   "id": "fce6249cf5af4774"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import PeftModel\n",
    "\n",
    "# Reload the base model and resize its embeddings to the tokenizer's length\n",
    "# before attaching the adapter.\n",
    "inference_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    low_cpu_mem_usage=True,\n",
    "    # attn_implementation =\"flash_attention_2\",\n",
    ")\n",
    "inference_model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "inference_model = PeftModel.from_pretrained(inference_model, \"smangrul/mistral_lora_clm_with_added_tokens\")\n",
    "inference_model.to(\"cuda\")\n",
    "inference_model.eval()\n",
    "\n",
    "# `batch`, `context` and `target` come from the evaluation cell above, so that cell\n",
    "# must have been run first; the same sample is generated again with the reloaded model.\n",
    "output_tokens = inference_model.generate(\n",
    "    **batch,\n",
    "    max_new_tokens=256,\n",
    "    do_sample=True,\n",
    "    temperature=0.2,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.pad_token_id,\n",
    ")\n",
    "\n",
    "# Keep only the text generated after the end-of-context marker.\n",
    "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n",
    "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")"
   ],
   "id": "1599153999267b6d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "c603576d44c149ce"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
