{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Fine-tune large models using 🤗 `peft` adapters, `transformers` & `bitsandbytes`\n",
    "\n",
    "In this tutorial we will cover how to fine-tune large language models using the recent `peft` library, with `bitsandbytes` for loading large models in 8-bit. The fine-tuning method relies on a recent technique called \"Low Rank Adapters\" (LoRA): instead of fine-tuning the entire model, you only have to fine-tune these adapters and load them properly inside the model. After fine-tuning, you can also share your adapters on the 🤗 Hub and load them very easily. Let's get started!"
   ],
   "id": "692bc26e92481a13"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Install requirements\n",
    "\n",
    "First, run the cells below to install the requirements:"
   ],
   "id": "564f340fe1a251e4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "!pip install -q bitsandbytes datasets accelerate\n",
    "!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git"
   ],
   "id": "a65a2cea97596264"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Model loading\n",
    "\n",
    "Let's load the `opt-6.7b` model. Its weights in half-precision (float16) take about 13GB on the Hub; loaded in 8-bit, they require only around 7GB of memory."
   ],
   "id": "9605552d2b19fd2b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import bitsandbytes as bnb\n",
    "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"facebook/opt-6.7b\",\n",
    "                                             quantization_config=BitsAndBytesConfig(load_in_8bit=True))\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-6.7b\")"
   ],
   "id": "ae497e041c3bdcb"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Prepare model for training\n",
    "\n",
    "Some pre-processing is needed before training such an int8 model with `peft`, so let's import the utility function `prepare_model_for_kbit_training`, which will:\n",
    "- Cast all the non-`int8` modules to full precision (`fp32`) for stability\n",
    "- Add a `forward_hook` to the input embedding layer to enable gradient computation of the input hidden states\n",
    "- Enable gradient checkpointing for more memory-efficient training"
   ],
   "id": "32310276e6c94a8a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import prepare_model_for_kbit_training\n",
    "\n",
    "model = prepare_model_for_kbit_training(model)"
   ],
   "id": "e01e65e99f1a8c63"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Apply LoRA\n",
    "\n",
    "Here comes the magic of `peft`! Let's wrap the model in a `PeftModel` and specify that we are going to use low-rank adapters (LoRA) via the `get_peft_model` utility function from `peft`."
   ],
   "id": "78a81e137b229193"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Prints the number of trainable parameters in the model.\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        all_param += param.numel()\n",
    "        if param.requires_grad:\n",
    "            trainable_params += param.numel()\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n",
    "    )"
   ],
   "id": "b03bcc5211ddf76e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    r=16, lora_alpha=32, target_modules=[\"q_proj\", \"v_proj\"], lora_dropout=0.05, bias=\"none\", task_type=\"CAUSAL_LM\"\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, config)\n",
    "print_trainable_parameters(model)"
   ],
   "id": "1ac7a587a8596c7e"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Training",
   "id": "84e7abb7c79e0cc5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import transformers\n",
    "from datasets import load_dataset\n",
    "\n",
    "data = load_dataset(\"Abirate/english_quotes\")\n",
    "data = data.map(lambda samples: tokenizer(samples[\"quote\"]), batched=True)\n",
    "\n",
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    train_dataset=data[\"train\"],\n",
    "    args=transformers.TrainingArguments(\n",
    "        per_device_train_batch_size=4,\n",
    "        gradient_accumulation_steps=4,\n",
    "        warmup_steps=100,\n",
    "        max_steps=200,\n",
    "        learning_rate=2e-4,\n",
    "        fp16=True,\n",
    "        logging_steps=1,\n",
    "        output_dir=\"outputs\",\n",
    "    ),\n",
    "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    ")\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
    "trainer.train()"
   ],
   "id": "bc81b4c2bb4cf3ed"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Share adapters on the 🤗 Hub",
   "id": "580068d39f30bb55"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ],
   "id": "f6c409ed5904ae17"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model.push_to_hub(\"ybelkada/opt-6.7b-lora\", token=True)",
   "id": "7929a9f5ed256dad"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Load adapters from the Hub\n",
    "\n",
    "You can also directly load adapters from the Hub using the commands below:"
   ],
   "id": "26112f04440d77bc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
    "\n",
    "peft_model_id = \"ybelkada/opt-6.7b-lora\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    config.base_model_name_or_path, return_dict=True, quantization_config=BitsAndBytesConfig(load_in_8bit=True),\n",
    "    device_map=\"auto\"\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
    "\n",
    "# Load the Lora model\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ],
   "id": "d62f13067728fdd"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Inference\n",
    "\n",
    "You can then use either the model you just trained or the one loaded from the 🤗 Hub for inference, as you usually would in `transformers`."
   ],
   "id": "b671c3e87667f2c0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Move the inputs to the same device as the model before generating\n",
    "batch = tokenizer(\"Two things are infinite: \", return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "with torch.autocast(\"cuda\"):\n",
    "    output_tokens = model.generate(**batch, max_new_tokens=50)\n",
    "\n",
    "print(\"\\n\\n\", tokenizer.decode(output_tokens[0], skip_special_tokens=True))"
   ],
   "id": "d551287d76c3a852"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "As you can see, after fine-tuning for only a few steps we have almost recovered the Albert Einstein quote that is present in the [training data](https://huggingface.co/datasets/Abirate/english_quotes).",
   "id": "a6ad591252f6bba9"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
