{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f239612-620e-4430-8685-9fdc6b179b41",
   "metadata": {},
   "source": [
    "# Training PEFT models with new tokens being added to the embedding layers and tokenizer\n",
    "\n",
    "In this example, we will learn how to train a LoRA model when adding new tokens to the tokenizer and model. \n",
    "This is a common usecase when doing the following:\n",
    "1. Instruction finetuning with new tokens beind added such as `<|user|>`, `<|assistant|>`, `<|system|>`, `</s>`, `<s>` to properly format the conversations\n",
    "2. Finetuning on a specific language wherein language spoecific tokens are added, e.g., korean tokens being added to vocabulary for finetuning LLM on Korean datasets.\n",
    "3. Instruction finetuning to return outputs in certain format to enable agent behaviour new tokens such as `<|FUNCTIONS|>`, `<|BROWSE|>`, `<|TEXT2IMAGE|>`, `<|ASR|>`, `<|TTS|>`, `<|GENERATECODE|>`, `<|RAG|>`.\n",
    "\n",
    "In such cases, you add the Embedding modules to the LORA `target_modules`. PEFT will take care of saving the embedding layers with the new added tokens along with the adapter weights that were trained on the specific initialization of the embeddings weights of the added tokens."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b27c55e8-edaa-4059-90bc-d6096d596902",
   "metadata": {},
   "source": [
    "Let's import the necessary libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6f864c90",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "os.environ[\"WANDB_PROJECT\"] = \"PeftExamples\"\n",
    "import transformers\n",
    "from peft import (\n",
    "    LoraConfig,\n",
    "    PeftConfig,\n",
    "    PeftModel,\n",
    "    get_peft_model,\n",
    "    prepare_model_for_kbit_training,\n",
    ")\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    HfArgumentParser,\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    "    default_data_collator,\n",
    ")\n",
    "import torch\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Optional\n",
    "from dataclass_csv import DataclassReader\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from enum import Enum"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74950a3f-bb63-4ce5-9e2b-1b83f92b13a2",
   "metadata": {},
   "source": [
    "## Prepare Model and Tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76763f5e-64b2-409b-8845-ae5589f8a4e0",
   "metadata": {},
   "source": [
    "Now, we will be adding 27 new tokens as well as replace the existing pad, bos and eos tokens of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd0498ea-547e-418d-bf13-c9abafdd5476",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpecialTokens(str, Enum):\n",
    "    begin_target = \"<|begintarget|>\"\n",
    "    end_target = \"<|endtarget|>\"\n",
    "    begin_context = \"<|begincontext|>\"\n",
    "    end_context = \"<|endcontext|>\"\n",
    "    system = \"<|system|>\"\n",
    "    user = \"<|user|>\"\n",
    "    begin_last_user_utterance = \"<|beginlastuserutterance|>\"\n",
    "    end_last_user_utterance = \"<|endlastuserutterance|>\"\n",
    "    begin_dsts = \"<|begindsts|>\"\n",
    "    end_dsts = \"<|enddsts|>\"\n",
    "    begin_dst = \"<|begindst|>\"\n",
    "    end_dst = \"<|enddst|>\"\n",
    "    begin_belief = \"<|beginbelief|>\"\n",
    "    end_belief = \"<|endbelief|>\"\n",
    "    begin_response = \"<|beginresponse|>\"\n",
    "    end_response = \"<|endresponse|>\"\n",
    "    begin_action = \"<|beginaction|>\"\n",
    "    end_action = \"<|endaction|>\"\n",
    "    begin_user_action = \"<|beginuseraction|>\"\n",
    "    end_user_action = \"<|enduseraction|>\"\n",
    "    sys_actions = \"<|sysactions|>\"\n",
    "    begin_intent = \"<|beginintent|>\"\n",
    "    end_intent = \"<|endintent|>\"\n",
    "    begin_requested_slots = \"<|beginrequestedslots|>\"\n",
    "    end_requested_slots = \"<|endrequestedslots|>\"\n",
    "    pad_token = \"<|pad|>\"\n",
    "    bos_token = \"<|startoftext|>\"\n",
    "\n",
    "    @classmethod\n",
    "    def list(cls):\n",
    "        return [c.value for c in cls]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae4a4255-5f13-4eef-a024-4f1de0f2173b",
   "metadata": {},
   "source": [
    "We will be finetuning Mistral-7B model. Let's load the tokenizer and add the special tokens followed by loading the base model and resizzing the embedding layers to accomodate the newly added tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f0eedef9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "91c67b6377fc4dd7977bf544de784d51",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Embedding(32027, 4096)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_name = \"mistralai/Mistral-7B-v0.1\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    model_name,\n",
    "    pad_token=SpecialTokens.pad_token.value,\n",
    "    bos_token=SpecialTokens.bos_token.value,\n",
    "    eos_token=SpecialTokens.end_target.value,\n",
    "    additional_special_tokens=SpecialTokens.list(),\n",
    ")\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    low_cpu_mem_usage=True\n",
    "    # use_flash_attention_2=True, # leading to an error\n",
    ")\n",
    "model.resize_token_embeddings(len(tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88439ed6-9974-4918-80df-ec78b05b4185",
   "metadata": {},
   "source": [
    "## Apply LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "80967087",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 31,886,720 || all params: 7,273,840,000 || trainable%: 0.43837532857472805\n",
      "None\n",
      "PeftModel(\n",
      "  (base_model): LoraModel(\n",
      "    (model): MistralForCausalLM(\n",
      "      (model): MistralModel(\n",
      "        (embed_tokens): lora.Embedding(\n",
      "          (base_layer): Embedding(32027, 4096)\n",
      "          (lora_dropout): ModuleDict(\n",
      "            (default): Identity()\n",
      "          )\n",
      "          (lora_A): ModuleDict()\n",
      "          (lora_B): ModuleDict()\n",
      "          (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 64x32027])\n",
      "          (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 4096x64])\n",
      "        )\n",
      "        (layers): ModuleList(\n",
      "          (0-31): 32 x MistralDecoderLayer(\n",
      "            (self_attn): MistralAttention(\n",
      "              (q_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=4096, out_features=64, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=64, out_features=4096, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "              )\n",
      "              (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
      "              (v_proj): lora.Linear(\n",
      "                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)\n",
      "                (lora_dropout): ModuleDict(\n",
      "                  (default): Identity()\n",
      "                )\n",
      "                (lora_A): ModuleDict(\n",
      "                  (default): Linear(in_features=4096, out_features=64, bias=False)\n",
      "                )\n",
      "                (lora_B): ModuleDict(\n",
      "                  (default): Linear(in_features=64, out_features=1024, bias=False)\n",
      "                )\n",
      "                (lora_embedding_A): ParameterDict()\n",
      "                (lora_embedding_B): ParameterDict()\n",
      "              )\n",
      "              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
      "              (rotary_emb): MistralRotaryEmbedding()\n",
      "            )\n",
      "            (mlp): MistralMLP(\n",
      "              (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
      "              (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
      "              (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
      "              (act_fn): SiLU()\n",
      "            )\n",
      "            (input_layernorm): MistralRMSNorm()\n",
      "            (post_attention_layernorm): MistralRMSNorm()\n",
      "          )\n",
      "        )\n",
      "        (norm): MistralRMSNorm()\n",
      "      )\n",
      "      (lm_head): lora.Linear(\n",
      "        (base_layer): Linear(in_features=4096, out_features=32027, bias=False)\n",
      "        (lora_dropout): ModuleDict(\n",
      "          (default): Identity()\n",
      "        )\n",
      "        (lora_A): ModuleDict(\n",
      "          (default): Linear(in_features=4096, out_features=64, bias=False)\n",
      "        )\n",
      "        (lora_B): ModuleDict(\n",
      "          (default): Linear(in_features=64, out_features=32027, bias=False)\n",
      "        )\n",
      "        (lora_embedding_A): ParameterDict()\n",
      "        (lora_embedding_B): ParameterDict()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "config = LoraConfig(\n",
    "    r=64, lora_alpha=128, lora_dropout=0.0, target_modules=[\"embed_tokens\", \"lm_head\", \"q_proj\", \"v_proj\"]\n",
    ")\n",
    "model = get_peft_model(model, config)\n",
    "print(model.print_trainable_parameters())\n",
    "print(model)"
   ]
  },
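  {
   "cell_type": "markdown",
   "id": "lora-param-count-note",
   "metadata": {},
   "source": [
    "The trainable-parameter count reported above can be reproduced by hand as a rough check. The sketch below assumes the shapes shown in the printed model (hidden size 4096, `v_proj` output size 1024, 32 decoder layers, a resized vocabulary of 32027) and that each LoRA-wrapped layer adds two rank-`r` matrices. The helper `lora_params` is hypothetical and only for illustration; it is not part of PEFT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lora-param-count-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative arithmetic only; all shapes are read off the printed model above.\n",
    "r = 64\n",
    "vocab, hidden, kv_dim, n_layers = 32027, 4096, 1024, 32\n",
    "\n",
    "\n",
    "def lora_params(in_features, out_features, rank):\n",
    "    # One LoRA pair: A has rank * in_features entries, B has out_features * rank entries.\n",
    "    return rank * in_features + out_features * rank\n",
    "\n",
    "\n",
    "embed = lora_params(vocab, hidden, r)  # embed_tokens\n",
    "head = lora_params(hidden, vocab, r)  # lm_head\n",
    "attn = n_layers * (lora_params(hidden, hidden, r) + lora_params(hidden, kv_dim, r))  # q_proj + v_proj\n",
    "total = embed + head + attn\n",
    "assert total == 31_886_720  # matches the trainable params reported above\n",
    "print(f\"{total:,}\")"
   ]
  },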
  {
   "cell_type": "markdown",
   "id": "15ac9945-4fcb-45f4-9478-d99a25a519cc",
   "metadata": {},
   "source": [
    "## Preapre Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c6980d59-42d4-4a27-84cc-a9719302088b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33d9539232da48f3ae922216b98ae462",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/986 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7a33811d93742099140240cad91b679",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/247 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"smangrul/assistant_chatbot_dataset\")\n",
    "dataset = dataset[\"train\"].train_test_split(0.2)\n",
    "\n",
    "text_column = \"context\"\n",
    "label_column = \"target\"\n",
    "max_length = 512\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(examples[text_column])\n",
    "    labels = tokenizer(targets, add_special_tokens=False)  # don't add bos token because we concatenate with inputs\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = model_inputs[\"input_ids\"][i][:max_length]\n",
    "        model_inputs[\"attention_mask\"][i] = model_inputs[\"attention_mask\"][i][:max_length]\n",
    "        labels[\"input_ids\"][i] = labels[\"input_ids\"][i][:max_length]\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]"
   ]
  },
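  {
   "cell_type": "markdown",
   "id": "toy-preprocess-note",
   "metadata": {},
   "source": [
    "To make the preprocessing concrete, here is a toy version of the same steps on made-up token IDs: the target plus an EOS token is appended to the context, the context is masked with `-100` in the labels, and the sequence is left-padded to a fixed length. The IDs, `toy_max_length`, and the other names are hypothetical and chosen only for illustration; they are not real tokenizer output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-preprocess-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration with made-up token IDs (not real tokenizer output).\n",
    "toy_pad_id, toy_eos_id, toy_max_length = 0, 2, 8\n",
    "context_ids = [11, 12, 13]\n",
    "target_ids = [21, 22]\n",
    "\n",
    "# Append EOS to the target and mask the context so that it does not contribute to the loss.\n",
    "toy_input_ids = context_ids + target_ids + [toy_eos_id]\n",
    "toy_labels = [-100] * len(context_ids) + target_ids + [toy_eos_id]\n",
    "\n",
    "# Left-pad to the fixed length; padded positions get attention 0 and label -100.\n",
    "n_pad = toy_max_length - len(toy_input_ids)\n",
    "toy_attention_mask = [0] * n_pad + [1] * len(toy_input_ids)\n",
    "toy_input_ids = [toy_pad_id] * n_pad + toy_input_ids\n",
    "toy_labels = [-100] * n_pad + toy_labels\n",
    "\n",
    "assert len(toy_input_ids) == len(toy_attention_mask) == len(toy_labels) == toy_max_length\n",
    "print(toy_input_ids)\n",
    "print(toy_attention_mask)\n",
    "print(toy_labels)"
   ]
  },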
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5671b1ee-dca4-4705-8399-5c2967b9fb5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 986\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3f38888e-4382-415b-869d-7202a816606a",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=8, pin_memory=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "53b9e552-4c5d-43e8-a9cd-8073af8d4280",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[32002, 32002, 32002,  ..., 32017, 32001, 32001],\n",
       "         [32002, 32002, 32002,  ..., 32017, 32001, 32001],\n",
       "         [32002, 32002, 32002,  ..., 32017, 32001, 32001],\n",
       "         ...,\n",
       "         [32002, 32002, 32002,  ..., 32017, 32001, 32001],\n",
       "         [32002, 32002, 32002,  ..., 32017, 32001, 32001],\n",
       "         [32002, 32002, 32002,  ..., 32017, 32001, 32001]]),\n",
       " 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],\n",
       "         [0, 0, 0,  ..., 1, 1, 1],\n",
       "         [0, 0, 0,  ..., 1, 1, 1],\n",
       "         ...,\n",
       "         [0, 0, 0,  ..., 1, 1, 1],\n",
       "         [0, 0, 0,  ..., 1, 1, 1],\n",
       "         [0, 0, 0,  ..., 1, 1, 1]]),\n",
       " 'labels': tensor([[ -100,  -100,  -100,  ..., 32017, 32001, 32001],\n",
       "         [ -100,  -100,  -100,  ..., 32017, 32001, 32001],\n",
       "         [ -100,  -100,  -100,  ..., 32017, 32001, 32001],\n",
       "         ...,\n",
       "         [ -100,  -100,  -100,  ..., 32017, 32001, 32001],\n",
       "         [ -100,  -100,  -100,  ..., 32017, 32001, 32001],\n",
       "         [ -100,  -100,  -100,  ..., 32017, 32001, 32001]])}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7de31ee2-185e-4658-9ad1-ae5f6bc3a611",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|
pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|><|begincontext|><|user|> Can you find me place to eat?<|system|> What kind of food would you like to have and where would you like me to search in?<|user|> Food kind of California will be perfect in SF.<|system|> There are 10 restaurants, Al's Place is one of the good restaurant in San Francisco.<|user|> Can you look for any other restaurant?<|system|> Alta Msp is one of the good restaurant in San Francisco.<|beginlastuserutterance|> Can you find me the address?<|endlastuserutterance|><|endcontext|><|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginrequestedslots|> Restaurants^street_address<|endrequestedslots|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->California<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^street_address~1275 Minnesota Street<|endaction|><|beginresponse|> The street address of the restaurant is 1275 Minnesota Street.<|endresponse|><|endtarget|><|endtarget|>\""
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(train_dataset[0][\"input_ids\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "239d1c83-196d-471e-9bf7-5f36dafa9894",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ec80d6ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msmangrul\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.0"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/raid/sourab/temp/wandb/run-20231128_230934-edod21gq</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/smangrul/PeftExamples/runs/edod21gq' target=\"_blank\">ethereal-eon-1</a></strong> to <a href='https://wandb.ai/smangrul/PeftExamples' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/smangrul/PeftExamples' target=\"_blank\">https://wandb.ai/smangrul/PeftExamples</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/smangrul/PeftExamples/runs/edod21gq' target=\"_blank\">https://wandb.ai/smangrul/PeftExamples/runs/edod21gq</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='246' max='246' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [246/246 05:51, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>5.189800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>3.745500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>2.371500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>1.630200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>1.302600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.999400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.704100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.527800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>0.509700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.382300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>0.318200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>0.323500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>0.263400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>0.290900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>0.277400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>0.232800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>0.223600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>0.229600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>0.233100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.210200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>210</td>\n",
       "      <td>0.245800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>220</td>\n",
       "      <td>0.197300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>230</td>\n",
       "      <td>0.210100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>240</td>\n",
       "      <td>0.209800</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=246, training_loss=0.8516577879587809, metrics={'train_runtime': 354.9013, 'train_samples_per_second': 5.556, 'train_steps_per_second': 0.693, 'total_flos': 4.318233532091597e+16, 'train_loss': 0.8516577879587809, 'epoch': 2.0})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"mistral_lora_clm_with_added_tokens\",\n",
    "    num_train_epochs=2,\n",
    "    save_total_limit=5,\n",
    "    per_device_train_batch_size=8,\n",
    "    warmup_steps=10,\n",
    "    weight_decay=0.0001,\n",
    "    dataloader_drop_last=True,\n",
    "    bf16=True,\n",
    "    logging_steps=10,\n",
    "    learning_rate=1e-5,\n",
    "    gradient_checkpointing=True,\n",
    "    gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
    "    remove_unused_columns=False,\n",
    "    hub_model_id=\"smangrul/mistral_lora_clm_with_added_tokens\",\n",
    "    push_to_hub=True,\n",
    "    hub_private_repo=True,\n",
    ")\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    data_collator=default_data_collator,\n",
    ")\n",
    "# model.config.use_cache = False\n",
    "trainer.train()"
   ]
  },
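  {
   "cell_type": "markdown",
   "id": "train-steps-note",
   "metadata": {},
   "source": [
    "The 246 optimizer steps shown in the progress bar above follow from the dataset size and batch settings. A quick check, assuming a single visible GPU and the default of no gradient accumulation (the latter is not set explicitly above, so it is an assumption here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "train-steps-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 986 training rows, per-device batch size 8, incomplete last batch dropped, 2 epochs.\n",
    "num_rows, per_device_batch, num_epochs = 986, 8, 2\n",
    "steps_per_epoch = num_rows // per_device_batch  # floor division because dataloader_drop_last=True\n",
    "total_steps = steps_per_epoch * num_epochs\n",
    "assert total_steps == 246\n",
    "print(steps_per_epoch, total_steps)"
   ]
  },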
  {
   "cell_type": "markdown",
   "id": "7bc1cbed-4eb9-4aaa-ab5f-5b91bf432307",
   "metadata": {},
   "source": [
    "# Check the model output on a sample from evaluation dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "71851793",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "context=\"<|begincontext|><|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n",
      "\n",
      " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> Great, the phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n",
      "\n",
      " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "i = random.randint(0, len(dataset[\"test\"]))\n",
    "context = dataset[\"test\"][i][\"context\"]\n",
    "\n",
    "batch = tokenizer(context, return_tensors=\"pt\")\n",
    "batch = {k: v.to(\"cuda\") for k, v in batch.items()}\n",
    "model.eval()\n",
    "output_tokens = model.generate(\n",
    "    **batch,\n",
    "    max_new_tokens=256,\n",
    "    do_sample=True,\n",
    "    temperature=0.2,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.pad_token_id,\n",
    ")\n",
    "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n",
    "target = dataset[\"test\"][i][\"target\"]\n",
    "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f940a660-2f7c-4a3a-b412-3f037aedb890",
   "metadata": {},
   "source": [
    "# Save the Adapter model "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ebe05e9-9b93-42f6-bba8-46b8cc3d100f",
   "metadata": {},
   "source": [
    "When the lora layers are applied to embedding layers, the corresponding base model embedding layers are also saved. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3d7459ba-caa8-4f10-aa70-89be4541cbdf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/raid/sourab/peft/src/peft/utils/save_and_load.py:128: UserWarning: Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\n",
      "  warnings.warn(\"Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\")\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8d23186832014f209939ab83e79da011",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3d831bc7d8843038364e821aacff5f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_model.safetensors:   0%|          | 0.00/1.18G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84cc7a2a3a474bb791d61e2357dd229e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "events.out.tfevents.1701209373.hf-dgx-01.667111.0:   0%|          | 0.00/8.52k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ce2025dd01647599c00578044512c8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "training_args.bin:   0%|          | 0.00/4.79k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/smangrul/mistral_lora_clm_with_added_tokens/commit/60ed7ea8bef10ce46d7a64229481dd1ad0e3d1c5', commit_message='Upload model', commit_description='', oid='60ed7ea8bef10ce46d7a64229481dd1ad0e3d1c5', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
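    "# Push the training artifacts in the output directory to the Hub.\n",
    "trainer.push_to_hub()\n",
    "# Push the adapter; the resized embedding layers are saved with it because they are in `target_modules`.\n",
    "trainer.model.push_to_hub(training_args.output_dir)"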
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66812cc4-f9a3-46c4-bcee-0cba03950685",
   "metadata": {},
   "source": [
    "# Check that the model loads as expected and generates plausible outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "589c46d7-d567-40b4-ab7d-e0a9e1cab40e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f98524da95b64a29a9016c6067313b2b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aaae3bc0f52f45bbaab60687b71fc4cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_config.json:   0%|          | 0.00/637 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1fc5754f41784d1aba00b93551894579",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_model.safetensors:   0%|          | 0.00/1.18G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "context=\"<|begincontext|><|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n",
      "\n",
      " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurant<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> The phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n",
      "\n",
      " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n"
     ]
    }
   ],
   "source": [
    "from peft import PeftModel\n",
    "\n",
    "inference_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    low_cpu_mem_usage=True,\n",
    "    # use_flash_attention_2=True,\n",
    ")\n",
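    "# Resize the base model embeddings to match the tokenizer with the added tokens before loading the adapter.\n",
    "inference_model.resize_token_embeddings(len(tokenizer))\n",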
    "\n",
    "inference_model = PeftModel.from_pretrained(inference_model, \"smangrul/mistral_lora_clm_with_added_tokens\")\n",
    "inference_model.to(\"cuda\")\n",
    "inference_model.eval()\n",
    "\n",
    "output_tokens = inference_model.generate(\n",
    "    **batch,\n",
    "    max_new_tokens=256,\n",
    "    do_sample=True,\n",
    "    temperature=0.2,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    pad_token_id=tokenizer.pad_token_id,\n",
    ")\n",
    "\n",
    "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n",
    "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
