{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5oqSnSaqLWAL"
      },
      "source": [
        "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d6c1x17tLWAR"
      },
      "source": [
        "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cQ6bxQaMLWAS"
      },
      "source": [
        "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.).  \n",
        "\n",
        "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project!  \n",
        "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview)  \n",
        "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JG3wax0uLWAU"
      },
      "source": [
        "## Key concepts\n",
        "\n",
        "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n",
        "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n",
        "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n",
        "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n",
        "\n",
        "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0ZhyNnhiLWAV"
      },
      "source": [
        "## Install dependencies\n",
        "\n",
        "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FXTyVTJcLWAV"
      },
      "outputs": [],
      "source": [
        "!pip install -Uq \"trl[peft]\" trackio bitsandbytes liger-kernel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OqlMF6oWLWAY"
      },
      "source": [
        "### Log in to Hugging Face"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2blL6-1_LWAa"
      },
      "source": [
        "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6OMeJOp7LWAc"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6HHscLIQLWAd"
      },
      "source": [
        "## Load Dataset\n",
        "\n",
        "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library.  \n",
        "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German.  \n",
        "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n",
        "\n",
        "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities.  \n",
        "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n",
        "\n",
        "For efficiency, we'll load only the **training split**:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dlQSKxTnLWAd"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n",
        "train_dataset = load_dataset(dataset_name, split=\"train\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bRHTwwZXLWAe"
      },
      "source": [
        "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zOBq8tVdLWAe",
        "outputId": "e12ab8ae-e00c-4e89-b489-dd448db8e13b"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n",
              "    num_rows: 1000\n",
              "})"
            ]
          },
          "execution_count": null,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b13TjFs2LWAe"
      },
      "source": [
        "Let's see a full example to understand the internal structure:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZON5mIMNLWAf",
        "outputId": "d01415eb-26cb-45ce-ad48-0388161eea28"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'reasoning_language': 'French',\n",
              " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n",
              " 'user': 'Can you show me the latest trends on Twitter right now?',\n",
              " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n",
              " 'final': 'Hey there!  While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly!  \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up.  \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads).  \\n4. **Check regional trends** – they often differ by location!  \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!).  \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n",
              " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n",
              "   'role': 'system',\n",
              "   'thinking': None},\n",
              "  {'content': 'Can you show me the latest trends on Twitter right now?',\n",
              "   'role': 'user',\n",
              "   'thinking': None},\n",
              "  {'content': 'Hey there!  While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly!  \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up.  \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads).  \\n4. **Check regional trends** – they often differ by location!  \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!).  \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n",
              "   'role': 'assistant',\n",
              "   'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}"
            ]
          },
          "execution_count": null,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_dataset[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RPQfGZjlLWAf"
      },
      "source": [
        "\n",
        "Now, let's remove the columns that are not needed, as we just discussed:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pCM6PoIzLWAf"
      },
      "outputs": [],
      "source": [
        "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BcU6E8KnLWAf"
      },
      "source": [
        "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*.  \n",
        "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details).  \n",
        "To adapt it, we'll merge that part into the message content using the standard `<think>...</think>` tags.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XQ2xYEq3LWAf"
      },
      "outputs": [],
      "source": [
        "def merge_thinking_and_remove_key(example):\n",
        "    new_messages = []\n",
        "    for msg in example[\"messages\"]:\n",
        "        content = msg[\"content\"]\n",
        "        thinking = msg.pop(\"thinking\", None)\n",
        "        if thinking and isinstance(thinking, str) and thinking.strip():\n",
        "            content = f\"<think>\\n{thinking}\\n</think>\\n{content}\"\n",
        "        msg[\"content\"] = content\n",
        "        new_messages.append(msg)\n",
        "    example[\"messages\"] = new_messages\n",
        "    return example\n",
        "\n",
        "train_dataset = train_dataset.map(merge_thinking_and_remove_key)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ewvZeKUcLWAf"
      },
      "source": [
        "## Load model and configure LoRA/QLoRA\n",
        "\n",
        "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n",
        "\n",
        "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sAWjOn9gLWAf"
      },
      "outputs": [],
      "source": [
        "# Select one model below by uncommenting the line you want to use 👇\n",
        "## Qwen\n",
        "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\"     # ⚠️ ~14.1 GB VRAM\n",
        "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\"                                          # ⚠️ ~12.8 GB VRAM\n",
        "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\"                        # ✅ ~10.8 GB VRAM\n",
        "\n",
        "## Llama\n",
        "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\"              # ✅ ~4.7 GB VRAM\n",
        "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\"              # ⚠️ ~10.9 GB VRAM\n",
        "\n",
        "## Gemma\n",
        "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\"                              # ❌ Upgrade to a higher tier of colab\n",
        "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\"                                  # ⚠️ ~6.8 GB VRAM\n",
        "\n",
        "## Granite\n",
        "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\"                      # ✅ ~3.3 GB VRAM\n",
        "\n",
        "## LFM2\n",
        "#model_id, output_dir = \"LiquidAI/LFM2-2.6B\", \"LFM2-2.6B-SFT\"                                     # ✅ ~5.89 GB VRAM"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BXY9Y0_dLWAf"
      },
      "source": [
        "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oyOoWFsLLWAg"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    attn_implementation=\"sdpa\",                   # Change to Flash Attention if GPU has support\n",
        "    dtype=torch.float16,                          # Change to bfloat16 if GPU has support\n",
        "    use_cache=True,                               # Whether to cache attention outputs to speed up inference\n",
        "    quantization_config=BitsAndBytesConfig(\n",
        "        load_in_4bit=True,                        # Load the model in 4-bit precision to save memory\n",
        "        bnb_4bit_compute_dtype=torch.float16,     # Data type used for internal computations in quantization\n",
        "        bnb_4bit_use_double_quant=True,           # Use double quantization to improve accuracy\n",
        "        bnb_4bit_quant_type=\"nf4\"                 # Type of quantization. \"nf4\" is recommended for recent LLMs\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L-_BpOdILWAg"
      },
      "source": [
        "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9EL-glV-LWAg"
      },
      "outputs": [],
      "source": [
        "from peft import LoraConfig\n",
        "\n",
        "# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
        "# For example, different LLMs might have different attention/projection layer names.\n",
        "peft_config = LoraConfig(\n",
        "    r=32,\n",
        "    lora_alpha=32,\n",
        "    target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
        ")"
      ]
    },
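    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what this configuration implies (this is the standard LoRA formulation, not something specific to this notebook): for each targeted linear layer with frozen weight $W \\in \\mathbb{R}^{d_{out} \\times d_{in}}$, LoRA trains two small matrices $A$ and $B$ and adds their scaled product to the frozen weight. The symbols $d_{in}$ and $d_{out}$ are introduced here for illustration only, and the scaling factor assumes PEFT's default (non rank-stabilized) behavior.\n",
        "\n",
        "$$\n",
        "W' = W + \\frac{\\alpha}{r} B A, \\qquad B \\in \\mathbb{R}^{d_{out} \\times r}, \\quad A \\in \\mathbb{R}^{r \\times d_{in}}\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\text{trainable parameters per layer} = r \\, (d_{in} + d_{out}), \\qquad \\frac{\\alpha}{r} = \\frac{32}{32} = 1\n",
        "$$\n",
        "\n",
        "With `r=32` and `lora_alpha=32`, the adapter update is applied with a scaling factor of 1, and each adapted layer adds $32 \\, (d_{in} + d_{out})$ trainable parameters while $W$ itself stays frozen."
      ]
    },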
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-i6BMpcaLWAg"
      },
      "source": [
        "## Train model\n",
        "\n",
        "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-doztoyxLWAg"
      },
      "outputs": [],
      "source": [
        "from trl import SFTConfig\n",
        "\n",
        "training_args = SFTConfig(\n",
        "    # Training schedule / optimization\n",
        "    per_device_train_batch_size = 1,      # Batch size per GPU\n",
        "    gradient_accumulation_steps = 4,      # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n",
        "    warmup_steps = 5,\n",
        "    # num_train_epochs = 1,               # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n",
        "    max_steps = 30,\n",
        "    learning_rate = 2e-4,                 # Learning rate for the optimizer\n",
        "    optim = \"paged_adamw_8bit\",           # Optimizer\n",
        "\n",
        "    # Logging / reporting\n",
        "    logging_steps=1,                      # Log training metrics every N steps\n",
        "    report_to=\"trackio\",                  # Experiment tracking tool\n",
        "    trackio_space_id=output_dir,          # HF Space where the experiment tracking will be saved\n",
        "    output_dir=output_dir,                # Where to save model checkpoints and logs\n",
        "\n",
        "    max_length=1024,                      # Maximum input sequence length\n",
        "    use_liger_kernel=True,                # Enable Liger kernel optimizations for faster training\n",
        "    activation_offloading=True,           # Offload activations to CPU to reduce GPU memory usage\n",
        "    gradient_checkpointing=True,          # Save memory by re-computing activations during backpropagation\n",
        "\n",
        "    # Hub integration\n",
        "    push_to_hub=True,                     # Automatically push the trained model to the Hugging Face Hub\n",
        "                                          # The model will be saved under your Hub account in the repository named `output_dir`\n",
        "\n",
        "    gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n",
        ")"
      ]
    },
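    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the schedule above concrete, here is a short worked calculation. It assumes a single GPU (as on a free Colab T4) and uses the 1,000 training rows reported when the dataset was loaded; the symbol names are introduced here for illustration only.\n",
        "\n",
        "$$\n",
        "B_{\\text{eff}} = \\underbrace{1}_{\\text{per-device batch size}} \\times \\underbrace{4}_{\\text{accumulation steps}} \\times \\underbrace{1}_{\\text{GPUs}} = 4\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\text{samples seen} = \\text{max\\_steps} \\times B_{\\text{eff}} = 30 \\times 4 = 120, \\qquad \\frac{120}{1000} = 0.12 \\ \\text{epochs}\n",
        "$$\n",
        "\n",
        "Under these assumptions, this short run covers only about 12% of the training split. For a full pass over the data, set `num_train_epochs` instead of `max_steps`."
      ]
    },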
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gz4ggYeeLWAg"
      },
      "source": [
        "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Yx1wkv_LWAg"
      },
      "outputs": [],
      "source": [
        "from trl import SFTTrainer\n",
        "\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    peft_config=peft_config\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0MsNw3uLLWAh"
      },
      "source": [
        "Show memory stats before training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YIuBi-ZYLWAh",
        "outputId": "7f381ba0-fe90-4c6f-df0a-938a29be4e9e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "GPU = Tesla T4. Max memory = 14.741 GB.\n",
            "12.074 GB of memory reserved.\n"
          ]
        }
      ],
      "source": [
        "gpu_stats = torch.cuda.get_device_properties(0)\n",
        "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
        "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
        "\n",
        "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
        "print(f\"{start_gpu_memory} GB of memory reserved.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_6G6pMGeLWAh"
      },
      "source": [
        "And train!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "glj5UPwWLWAh",
        "outputId": "b0a046c7-f76b-42a6-d870-f54470297971"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "* Trackio project initialized: huggingface\n",
            "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n",
            "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n",
            "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div><iframe src=\"https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\" width=\"100%\" height=\"1000px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "* Created new run: sergiopaniego-1761318512\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [30/30 1:08:22, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.136300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.303800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.362700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>1.469700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>1.204200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>1.202700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>1.097200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>1.166800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>0.916300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>0.965400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>1.035500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>0.947200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>0.992000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>0.995800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>1.174500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>1.208800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>0.815400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>0.906700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>0.757500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>0.872900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>0.920800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>1.017600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>0.764300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>1.043100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>0.956400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>0.884800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>1.081900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>0.918200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>0.961500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>0.822700</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "* Run finished. Uploading logs to Trackio (please wait...)\n"
          ]
        }
      ],
      "source": [
        "trainer_stats = trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aULbOL3mLWAh"
      },
      "source": [
        "Show memory stats after training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qp3m9sfXLWAh",
        "outputId": "597fefc7-5510-4839-ce10-981a0aca25e8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "4249.8883 seconds used for training.\n",
            "70.83 minutes used for training.\n",
            "Peak reserved memory = 14.041 GB.\n",
            "Peak reserved memory for training = 1.967 GB.\n",
            "Peak reserved memory % of max memory = 95.251 %.\n",
            "Peak reserved memory for training % of max memory = 13.344 %.\n"
          ]
        }
      ],
      "source": [
        "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
        "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
        "used_percentage = round(used_memory / max_memory * 100, 3)\n",
        "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
        "\n",
        "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
        "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
        "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
        "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
        "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
        "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VJOMCsMjLWAh"
      },
      "source": [
        "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FQNUkzVqLWAi"
      },
      "source": [
        "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XuCiCqj6LWAj"
      },
      "source": [
        "## Saving fine tuned model\n",
        "\n",
        "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kMHh7_gFLWAj"
      },
      "outputs": [],
      "source": [
        "trainer.save_model(output_dir)\n",
        "trainer.push_to_hub(dataset_name=dataset_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rbx-Bz9yLWAq"
      },
      "source": [
        "## Load the fine-tuned model and run inference\n",
        "\n",
        "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c4VwuANtLWAr"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "from peft import PeftModel\n",
        "\n",
        "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n",
        "\n",
        "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vG3ejWruLWAr"
      },
      "source": [
        "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EYiDkd-aLWAr"
      },
      "outputs": [],
      "source": [
        "messages = [\n",
        "  {\n",
        "      'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n",
        "      'role': 'system',\n",
        "  },\n",
        "  {\n",
        "      'content': \"Can you check how many followers I currently have on my Twitter account?\",\n",
        "      'role': 'user',\n",
        "  }\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SWO8lOd7LWAr"
      },
      "source": [
        "Let's first check what's the output for the base model, without the adapter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mt4uuTcQLWAr",
        "outputId": "98f07424-3506-40d1-9e33-d4e495ba171a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<think>\n",
            "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n",
            "\n",
            "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n",
            "\n",
            "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n",
            "\n",
            "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n",
            "\n",
            "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n",
            "</think>\n",
            "\n",
            "Nein.\n"
          ]
        }
      ],
      "source": [
        "text = tokenizer.apply_chat_template(\n",
        "    messages, tokenize=False, add_generation_prompt=True\n",
        ")\n",
        "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n",
        "\n",
        "generated_ids = base_model.generate(\n",
        "    **model_inputs,\n",
        "    max_new_tokens=512\n",
        ")\n",
        "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
        "\n",
        "# Decode and extract model response\n",
        "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
        "print(generated_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fj3FIx9pLWAr"
      },
      "source": [
        "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CmRfkvacLWAs"
      },
      "outputs": [],
      "source": [
        "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5UNOw-E0LWAs",
        "outputId": "19e227c1-4211-447e-a625-14e131912759"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<think>\n",
            "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n",
            "</think>\n",
            "\n",
            "No\n"
          ]
        }
      ],
      "source": [
        "text = tokenizer.apply_chat_template(\n",
        "    messages, tokenize=False, add_generation_prompt=True\n",
        ")\n",
        "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n",
        "\n",
        "generated_ids = fine_tuned_model.generate(\n",
        "    **model_inputs,\n",
        "    max_new_tokens=512\n",
        ")\n",
        "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
        "\n",
        "# Decode and extract model response\n",
        "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
        "print(generated_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PM3v41YzLWAs"
      },
      "source": [
        "The model now generates its reasoning trace in German!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w-9B5m__LWAs"
      },
      "source": [
        "## Inference and Serving with vLLM\n",
        "\n",
        "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NNmyG47aLWAv"
      },
      "outputs": [],
      "source": [
        "!pip install -qU vllm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iJ8DnsUxLWAw"
      },
      "source": [
        "### Push Merged Model (for LoRA or QLoRA Training)\n",
        "\n",
        "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aPzZ_7KDLWAw"
      },
      "outputs": [],
      "source": [
        "model_merged = fine_tuned_model.merge_and_unload()\n",
        "\n",
        "save_dir = f\"{output_dir}-merged\"\n",
        "\n",
        "model_merged.save_pretrained(save_dir)\n",
        "tokenizer.save_pretrained(save_dir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k1Cvrkn3LWAw"
      },
      "outputs": [],
      "source": [
        "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n",
        "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pR69AaJ3LWAx"
      },
      "source": [
        "### Performing Inference with vLLM\n",
        "\n",
        "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UX17ZoPQLWAx"
      },
      "outputs": [],
      "source": [
        "from vllm import LLM, SamplingParams\n",
        "from transformers import AutoTokenizer\n",
        "import torch\n",
        "\n",
        "llm = LLM(\n",
        "    model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n",
        "    model_impl=\"transformers\",                  # Select the transformers model implementation\n",
        "    max_model_len=512,                         # Reduced for efficiency\n",
        "    dtype=torch.float16\n",
        ")\n",
        "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\")  # Replace with your HF username or organization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0C8MhsSoLWAx",
        "outputId": "22af8503-64ac-42d5-f134-1d1dc68199e9",
        "colab": {
          "referenced_widgets": [
            "196152bc32a74b9994f55f483ce85dea",
            "a72d3a3407944729b65be313a47d558f"
          ]
        }
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "196152bc32a74b9994f55f483ce85dea",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a72d3a3407944729b65be313a47d558f",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<think>\n",
            "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n",
            "</think>\n",
            "\n",
            "No\n"
          ]
        }
      ],
      "source": [
        "# Alternatively, use llm.chat()\n",
        "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
        "\n",
        "outputs = llm.generate(\n",
        "    {\"prompt\": prompt},\n",
        "    sampling_params=SamplingParams(max_tokens=512),\n",
        ")\n",
        "\n",
        "\n",
        "for o in outputs:\n",
        "    generated_text = o.outputs[0].text\n",
        "    print(generated_text)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "language_info": {
      "name": "python"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}