{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Introduction\n",
    "\n",
    "In this notebook, we will learn how to use [LoRA](https://arxiv.org/abs/2106.09685) from 🤗 PEFT to fine-tune a SegFormer model variant for semantic segmentation by ONLY using **14%** of the original trainable parameters of the model.\n",
    "\n",
    "LoRA adds low-rank \"update matrices\" to certain blocks in the underlying model (in this case the attention blocks) and ONLY trains those matrices during fine-tuning. During inference, these update matrices are _merged_ with the original model parameters. For more details, check out the [original LoRA paper](https://arxiv.org/abs/2106.09685).\n",
    "\n",
    "Let's get started by installing the dependencies."
   ],
   "id": "2a26f9e4db917a55"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Install dependencies\n",
    "\n",
    "Here we're installing `peft` from source to ensure we have access to all the bleeding edge features of `peft`. "
   ],
   "id": "b9ca5477b046c475"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "!pip install transformers accelerate evaluate datasets git+https://github.com/huggingface/peft -q",
   "id": "cea38caecb63baed"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Authentication\n",
    "\n",
    "We will share our fine-tuned model at the end of training. To do that, we authenticate with our 🤗 token, which is available from [here](https://huggingface.co/settings/tokens). If you don't have a 🤗 account already, we highly encourage you to create one; it's free!"
   ],
   "id": "813c6c10d5461900"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ],
   "id": "35adc7e2bb6c445c"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Load a dataset\n",
    "\n",
    "We're only loading the first 150 instances from the training set of the [SceneParse150 dataset](https://huggingface.co/datasets/scene_parse_150) to keep this example runtime short. "
   ],
   "id": "1f6c37bf58b70cee"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"scene_parse_150\", split=\"train[:150]\")"
   ],
   "id": "8607430c7d42f53"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Prepare train and test splits",
   "id": "67768a3b9e3e5158"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "ds = ds.train_test_split(test_size=0.1)\n",
    "train_ds = ds[\"train\"]\n",
    "test_ds = ds[\"test\"]"
   ],
   "id": "f22457195f748c78"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Prepare label mappers\n",
    "\n",
    "We create two dictionaries:\n",
    "\n",
    "* `label2id`: maps the semantic classes of the dataset to integer ids.\n",
    "* `id2label`: `label2id` reversed. "
   ],
   "id": "d1fc2792b7972519"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import json\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "repo_id = \"huggingface/label-files\"\n",
    "filename = \"ade20k-id2label.json\"\n",
    "# `hf_hub_download` replaces the deprecated `cached_download` + `hf_hub_url` pair\n",
    "with open(hf_hub_download(repo_id, filename, repo_type=\"dataset\"), \"r\") as f:\n",
    "    id2label = json.load(f)\n",
    "id2label = {int(k): v for k, v in id2label.items()}\n",
    "label2id = {v: k for k, v in id2label.items()}\n",
    "num_labels = len(id2label)"
   ],
   "id": "7a8755fc265e2689"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Prepare datasets for training and evaluation",
   "id": "dc4b96aab7ed3609"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import AutoImageProcessor\n",
    "\n",
    "checkpoint = \"nvidia/mit-b0\"\n",
    "image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)"
   ],
   "id": "c1b7006045c89e99"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from torchvision.transforms import ColorJitter\n",
    "\n",
    "jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)"
   ],
   "id": "20b3cebd832113a2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def handle_grayscale_image(image):\n",
    "    np_image = np.array(image)\n",
    "    if np_image.ndim == 2:\n",
    "        tiled_image = np.tile(np.expand_dims(np_image, -1), 3)\n",
    "        return Image.fromarray(tiled_image)\n",
    "    else:\n",
    "        return Image.fromarray(np_image)\n",
    "\n",
    "\n",
    "def train_transforms(example_batch):\n",
    "    images = [jitter(handle_grayscale_image(x)) for x in example_batch[\"image\"]]\n",
    "    labels = [x for x in example_batch[\"annotation\"]]\n",
    "    inputs = image_processor(images, labels)\n",
    "    return inputs\n",
    "\n",
    "\n",
    "def val_transforms(example_batch):\n",
    "    images = [handle_grayscale_image(x) for x in example_batch[\"image\"]]\n",
    "    labels = [x for x in example_batch[\"annotation\"]]\n",
    "    inputs = image_processor(images, labels)\n",
    "    return inputs"
   ],
   "id": "1672348ff198c17"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "train_ds.set_transform(train_transforms)\n",
    "test_ds.set_transform(val_transforms)"
   ],
   "id": "8ce2067f692b4a25"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Evaluation function\n",
    "\n",
    "Including a metric during training is often helpful for evaluating your model’s performance. You can quickly load an evaluation method with the [🤗 Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load the [mean Intersection over Union (IoU)](https://huggingface.co/spaces/evaluate-metric/mean_iou) metric (see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):"
   ],
   "id": "889f648cec2a7174"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"mean_iou\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    with torch.no_grad():\n",
    "        logits, labels = eval_pred\n",
    "        logits_tensor = torch.from_numpy(logits)\n",
    "        # scale the logits to the size of the label\n",
    "        logits_tensor = nn.functional.interpolate(\n",
    "            logits_tensor,\n",
    "            size=labels.shape[-2:],\n",
    "            mode=\"bilinear\",\n",
    "            align_corners=False,\n",
    "        ).argmax(dim=1)\n",
    "\n",
    "        pred_labels = logits_tensor.detach().cpu().numpy()\n",
    "        # currently using _compute instead of compute\n",
    "        # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576\n",
    "        metrics = metric._compute(\n",
    "            predictions=pred_labels,\n",
    "            references=labels,\n",
    "            num_labels=len(id2label),\n",
    "            ignore_index=0,\n",
    "            reduce_labels=image_processor.do_reduce_labels,\n",
    "        )\n",
    "\n",
    "        # add per category metrics as individual key-value pairs\n",
    "        per_category_accuracy = metrics.pop(\"per_category_accuracy\").tolist()\n",
    "        per_category_iou = metrics.pop(\"per_category_iou\").tolist()\n",
    "\n",
    "        metrics.update({f\"accuracy_{id2label[i]}\": v for i, v in enumerate(per_category_accuracy)})\n",
    "        metrics.update({f\"iou_{id2label[i]}\": v for i, v in enumerate(per_category_iou)})\n",
    "\n",
    "        return metrics"
   ],
   "id": "ae24074c654988c1"
  },
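  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "To build intuition for what `mean_iou` reports, here is a minimal sketch that computes per-class IoU by hand on two tiny, made-up label maps. The arrays `pred` and `ref` are illustrative assumptions, and the loop is a simplified stand-in rather than the 🤗 Evaluate implementation: it skips `ignore_index` and label reduction, both of which the `compute_metrics` function above passes to the metric."
   ],
   "id": "miou-toy-example-note"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up 2x3 label maps, for illustration only\n",
    "pred = np.array([[0, 0, 1], [1, 1, 2]])\n",
    "ref = np.array([[0, 1, 1], [1, 1, 2]])\n",
    "\n",
    "per_class_iou = {}\n",
    "for c in np.unique(ref):\n",
    "    # IoU for class c: pixels where both maps say c, divided by pixels where either does\n",
    "    intersection = np.logical_and(pred == c, ref == c).sum()\n",
    "    union = np.logical_or(pred == c, ref == c).sum()\n",
    "    per_class_iou[int(c)] = float(intersection / union)\n",
    "\n",
    "# Mean IoU is the unweighted average over the classes present in the reference\n",
    "toy_mean_iou = float(np.mean(list(per_class_iou.values())))\n",
    "print(per_class_iou, toy_mean_iou)"
   ],
   "id": "miou-toy-example-code"
  },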
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Load a base model\n",
    "\n",
    "For this example, we use the [SegFormer B0 variant](https://huggingface.co/nvidia/mit-b0). "
   ],
   "id": "b3f0f7d4ce776651"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Prints the number of trainable parameters in the model.\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        all_param += param.numel()\n",
    "        if param.requires_grad:\n",
    "            trainable_params += param.numel()\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}\"\n",
    "    )"
   ],
   "id": "855646580aaa73b8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "We pass the `label2id` and `id2label` dictionaries to let the `AutoModelForSemanticSegmentation` class know that we're interested in a custom base model where the decoder head should be randomly initialized for our custom dataset. Note, however, that the rest of the model parameters are pre-trained and will be fine-tuned in a regular transfer learning setup.\n",
    "\n",
    "We also notice that 100% of the parameters in the `model` are trainable. "
   ],
   "id": "48d47fa19b76d0c2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForSemanticSegmentation.from_pretrained(\n",
    "    checkpoint, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True\n",
    ")\n",
    "print_trainable_parameters(model)"
   ],
   "id": "70b84669a94dfed9"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Wrap `model` as a `PeftModel` for LoRA training\n",
    "\n",
    "This involves two steps:\n",
    "\n",
    "* Defining a config with `LoraConfig`\n",
    "* Wrapping the original `model` with `get_peft_model()` with the config defined in the step above. "
   ],
   "id": "611689febc1aad99"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    r=32,\n",
    "    lora_alpha=32,\n",
    "    target_modules=[\"query\", \"value\"],\n",
    "    lora_dropout=0.1,\n",
    "    bias=\"lora_only\",\n",
    "    modules_to_save=[\"decode_head\"],\n",
    ")\n",
    "lora_model = get_peft_model(model, config)\n",
    "print_trainable_parameters(lora_model)"
   ],
   "id": "ffd6049fd82a49df"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Let's unpack what's going on here.\n",
    "\n",
    "In order for LoRA to take effect, we need to specify the target modules to `LoraConfig` so that `PeftModel` knows which modules inside our model need to be amended with LoRA matrices. In this case, we're only interested in targeting the query and value matrices of the attention blocks of the base model. Since the parameters corresponding to these matrices are \"named\" with `query` and `value` respectively, we specify them accordingly in the `target_modules` argument of `LoraConfig`.\n",
    "\n",
    "We also specify `modules_to_save`. After we wrap our base model `model` with `PeftModel` along with the `config`, we get a new model where only the LoRA parameters (the so-called \"update matrices\") are trainable while the pre-trained parameters are kept frozen. The frozen parameters include those of the randomly initialized classifier too. This is NOT what we want when fine-tuning the base model on our custom dataset. To ensure that the classifier parameters are also trained, we specify `modules_to_save`. This also ensures that these modules are serialized alongside the LoRA trainable parameters when using utilities like `save_pretrained()` and `push_to_hub()`.\n",
    "\n",
    "Regarding the other parameters:\n",
    "\n",
    "* `r`: The dimension (rank) used by the LoRA update matrices.\n",
    "* `lora_alpha`: Scaling factor.\n",
    "* `bias`: Specifies whether the `bias` parameters should be trained. `lora_only` denotes that only the LoRA `bias` parameters will be trained.\n",
    "\n",
    "`r` controls the number of trainable parameters that LoRA adds, while `lora_alpha` scales the contribution of the update matrices. Together they give us the flexibility to balance a trade-off between end performance and compute efficiency.\n"
   ],
   "id": "a5a036dd8c375036"
  },
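  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "To make the roles of `r` and `lora_alpha` concrete, here is a minimal NumPy sketch of a single LoRA-adapted weight matrix, following the formulation in the LoRA paper. The layer size (256 x 256) is a made-up value for illustration and is not read from SegFormer, and the sketch is not PEFT's implementation. It suggests that the number of added parameters depends on `r`, while `lora_alpha / r` only scales the update."
   ],
   "id": "lora-toy-matrix-note"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical layer size, for illustration only (not read from the model)\n",
    "d_in, d_out = 256, 256\n",
    "r, lora_alpha = 32, 32  # same values as in the LoraConfig above\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "W = rng.normal(size=(d_out, d_in))  # frozen pre-trained weight\n",
    "A = rng.normal(size=(r, d_in)) * 0.01  # trainable, small random init\n",
    "B = np.zeros((d_out, r))  # trainable, zero init so the update starts at zero\n",
    "\n",
    "# Merged weight used at inference: W + (lora_alpha / r) * B @ A\n",
    "W_merged = W + (lora_alpha / r) * (B @ A)\n",
    "\n",
    "print(f\"frozen params: {W.size}, LoRA params: {A.size + B.size}\")\n",
    "print(\"merged weight equals the original at initialization:\", np.allclose(W, W_merged))"
   ],
   "id": "lora-toy-matrix-code"
  },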
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "We can also see how many parameters we're actually training. Since we're interested in performing **parameter-efficient fine-tuning**, we should expect fewer trainable parameters in the `lora_model` than in the original `model`, which is indeed the case here.\n",
    "\n",
    "For sanity, let's also manually verify the modules that are actually trainable in `lora_model`. "
   ],
   "id": "4ddbd5278dc2956e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "for name, param in lora_model.named_parameters():\n",
    "    if param.requires_grad:\n",
    "        print(name, param.shape)"
   ],
   "id": "a839f0b4a2f2e861"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We can confirm that only the LoRA parameters appended to the attention blocks and the `decode_head` parameters are trainable.",
   "id": "f09829f2754b8902"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Train!\n",
    "\n",
    "This is a three-step process:\n",
    "\n",
    "1. Define your training hyperparameters in [TrainingArguments](https://huggingface.co/docs/transformers/v4.26.0/en/main_classes/trainer#transformers.TrainingArguments). It is important you don’t remove unused columns because this’ll drop the image column. Without the image column, you can’t create `pixel_values`. Set `remove_unused_columns=False` to prevent this behavior! The only other required parameter is `output_dir`, which specifies where to save your model. At the end of each epoch, the `Trainer` will evaluate the IoU metric and save the training checkpoint.\n",
    "2. Pass the training arguments to [Trainer](https://huggingface.co/docs/transformers/v4.26.0/en/main_classes/trainer#transformers.Trainer) along with the model, the train and test datasets, and the `compute_metrics` function.\n",
    "3. Call `train()` to fine-tune your model.\n",
    "\n",
    "\n",
    "**Note** that this example is meant to walk you through the workflow when using PEFT for semantic segmentation. We didn't perform extensive hyperparameter tuning to achieve optimal results. "
   ],
   "id": "ab1a520fe21da0c6"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model_name = checkpoint.split(\"/\")[-1]\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-scene-parse-150-lora\",\n",
    "    learning_rate=5e-4,\n",
    "    num_train_epochs=50,\n",
    "    per_device_train_batch_size=4,\n",
    "    per_device_eval_batch_size=2,\n",
    "    save_total_limit=3,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_steps=5,\n",
    "    remove_unused_columns=False,\n",
    "    push_to_hub=True,\n",
    "    label_names=[\"labels\"],\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=lora_model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=test_ds,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ],
   "id": "a0ee47abb33bf40e"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Saving the model and inference \n",
    "\n",
    "Here we use the `save_pretrained()` method of the `lora_model` to save the *LoRA-only parameters* locally. However, you can also use the `push_to_hub()` method to upload these parameters directly to the Hugging Face Hub (as shown [here](https://colab.research.google.com/github/huggingface/peft/blob/main/examples/image_classification/image_classification_peft_lora.ipynb)). "
   ],
   "id": "7e6615d59c421e1d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model_id = \"segformer-scene-parse-150-lora\"\n",
    "lora_model.save_pretrained(model_id)"
   ],
   "id": "e0cbeb53212198bf"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We can see that the LoRA-only parameters are just **2.2 MB in size**! This greatly improves the portability when using very large models. ",
   "id": "61972cee3044689e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "!ls -lh {model_id}",
   "id": "7143c0ff366d8b0d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Let's now prepare our `inference_model` and run an inference. ",
   "id": "81f69680cf2cceae"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import PeftConfig, PeftModel\n",
    "\n",
    "config = PeftConfig.from_pretrained(model_id)\n",
    "model = AutoModelForSemanticSegmentation.from_pretrained(\n",
    "    checkpoint, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True\n",
    ")\n",
    "# Load the Lora model\n",
    "inference_model = PeftModel.from_pretrained(model, model_id)"
   ],
   "id": "b23de7f132b494cd"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Fetch an image.",
   "id": "b7c5744d867c31c3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import requests\n",
    "\n",
    "url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-image.png\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "image"
   ],
   "id": "10102bfe22c1f759"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Preprocess the image.",
   "id": "86f6e6974ad1e303"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# prepare image for the model\n",
    "encoding = image_processor(image.convert(\"RGB\"), return_tensors=\"pt\")\n",
    "print(encoding.pixel_values.shape)"
   ],
   "id": "ab0b0f9cf09d3cb9"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Run an inference. ",
   "id": "53a8f203881801d5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "with torch.no_grad():\n",
    "    outputs = inference_model(pixel_values=encoding.pixel_values)\n",
    "    logits = outputs.logits\n",
    "\n",
    "upsampled_logits = nn.functional.interpolate(\n",
    "    logits,\n",
    "    size=image.size[::-1],\n",
    "    mode=\"bilinear\",\n",
    "    align_corners=False,\n",
    ")\n",
    "\n",
    "pred_seg = upsampled_logits.argmax(dim=1)[0]"
   ],
   "id": "27a51fb662c33be3"
  },
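  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The upsampling step can be checked in isolation with random tensors. The sketch below assumes a 512 x 512 model input, logits at a quarter of that resolution (as SegFormer typically produces), 150 classes, and a made-up original image of width 640 and height 480; none of these numbers are read from the image fetched above. Note that PIL's `image.size` is `(width, height)`, which is why the cell above reverses it before passing it as `size`."
   ],
   "id": "upsample-toy-example-note"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Random stand-in for model logits: (batch, num_labels, height / 4, width / 4)\n",
    "dummy_logits = torch.randn(1, 150, 128, 128)\n",
    "\n",
    "# Hypothetical original image size as PIL reports it: (width, height)\n",
    "pil_size = (640, 480)\n",
    "\n",
    "dummy_upsampled = nn.functional.interpolate(\n",
    "    dummy_logits,\n",
    "    size=pil_size[::-1],  # interpolate expects (height, width)\n",
    "    mode=\"bilinear\",\n",
    "    align_corners=False,\n",
    ")\n",
    "dummy_seg = dummy_upsampled.argmax(dim=1)[0]  # one class id per pixel\n",
    "\n",
    "print(dummy_upsampled.shape, dummy_seg.shape)"
   ],
   "id": "upsample-toy-example-code"
  },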
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Visualize the results.\n",
    "\n",
    "We need a color palette to visualize the results. Here, we use [one provided by the TensorFlow Model Garden repository](https://github.com/tensorflow/models/blob/3f1ca33afe3c1631b733ea7e40c294273b9e406d/research/deeplab/utils/get_dataset_colormap.py#L51)."
   ],
   "id": "8704646500fd6ef8"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def ade_palette():\n",
    "    \"\"\"Creates a label colormap used in ADE20K segmentation benchmark.\n",
    "    Returns:\n",
    "    A colormap for visualizing segmentation results.\n",
    "    \"\"\"\n",
    "    return np.asarray(\n",
    "        [\n",
    "            [0, 0, 0],\n",
    "            [120, 120, 120],\n",
    "            [180, 120, 120],\n",
    "            [6, 230, 230],\n",
    "            [80, 50, 50],\n",
    "            [4, 200, 3],\n",
    "            [120, 120, 80],\n",
    "            [140, 140, 140],\n",
    "            [204, 5, 255],\n",
    "            [230, 230, 230],\n",
    "            [4, 250, 7],\n",
    "            [224, 5, 255],\n",
    "            [235, 255, 7],\n",
    "            [150, 5, 61],\n",
    "            [120, 120, 70],\n",
    "            [8, 255, 51],\n",
    "            [255, 6, 82],\n",
    "            [143, 255, 140],\n",
    "            [204, 255, 4],\n",
    "            [255, 51, 7],\n",
    "            [204, 70, 3],\n",
    "            [0, 102, 200],\n",
    "            [61, 230, 250],\n",
    "            [255, 6, 51],\n",
    "            [11, 102, 255],\n",
    "            [255, 7, 71],\n",
    "            [255, 9, 224],\n",
    "            [9, 7, 230],\n",
    "            [220, 220, 220],\n",
    "            [255, 9, 92],\n",
    "            [112, 9, 255],\n",
    "            [8, 255, 214],\n",
    "            [7, 255, 224],\n",
    "            [255, 184, 6],\n",
    "            [10, 255, 71],\n",
    "            [255, 41, 10],\n",
    "            [7, 255, 255],\n",
    "            [224, 255, 8],\n",
    "            [102, 8, 255],\n",
    "            [255, 61, 6],\n",
    "            [255, 194, 7],\n",
    "            [255, 122, 8],\n",
    "            [0, 255, 20],\n",
    "            [255, 8, 41],\n",
    "            [255, 5, 153],\n",
    "            [6, 51, 255],\n",
    "            [235, 12, 255],\n",
    "            [160, 150, 20],\n",
    "            [0, 163, 255],\n",
    "            [140, 140, 140],\n",
    "            [250, 10, 15],\n",
    "            [20, 255, 0],\n",
    "            [31, 255, 0],\n",
    "            [255, 31, 0],\n",
    "            [255, 224, 0],\n",
    "            [153, 255, 0],\n",
    "            [0, 0, 255],\n",
    "            [255, 71, 0],\n",
    "            [0, 235, 255],\n",
    "            [0, 173, 255],\n",
    "            [31, 0, 255],\n",
    "            [11, 200, 200],\n",
    "            [255, 82, 0],\n",
    "            [0, 255, 245],\n",
    "            [0, 61, 255],\n",
    "            [0, 255, 112],\n",
    "            [0, 255, 133],\n",
    "            [255, 0, 0],\n",
    "            [255, 163, 0],\n",
    "            [255, 102, 0],\n",
    "            [194, 255, 0],\n",
    "            [0, 143, 255],\n",
    "            [51, 255, 0],\n",
    "            [0, 82, 255],\n",
    "            [0, 255, 41],\n",
    "            [0, 255, 173],\n",
    "            [10, 0, 255],\n",
    "            [173, 255, 0],\n",
    "            [0, 255, 153],\n",
    "            [255, 92, 0],\n",
    "            [255, 0, 255],\n",
    "            [255, 0, 245],\n",
    "            [255, 0, 102],\n",
    "            [255, 173, 0],\n",
    "            [255, 0, 20],\n",
    "            [255, 184, 184],\n",
    "            [0, 31, 255],\n",
    "            [0, 255, 61],\n",
    "            [0, 71, 255],\n",
    "            [255, 0, 204],\n",
    "            [0, 255, 194],\n",
    "            [0, 255, 82],\n",
    "            [0, 10, 255],\n",
    "            [0, 112, 255],\n",
    "            [51, 0, 255],\n",
    "            [0, 194, 255],\n",
    "            [0, 122, 255],\n",
    "            [0, 255, 163],\n",
    "            [255, 153, 0],\n",
    "            [0, 255, 10],\n",
    "            [255, 112, 0],\n",
    "            [143, 255, 0],\n",
    "            [82, 0, 255],\n",
    "            [163, 255, 0],\n",
    "            [255, 235, 0],\n",
    "            [8, 184, 170],\n",
    "            [133, 0, 255],\n",
    "            [0, 255, 92],\n",
    "            [184, 0, 255],\n",
    "            [255, 0, 31],\n",
    "            [0, 184, 255],\n",
    "            [0, 214, 255],\n",
    "            [255, 0, 112],\n",
    "            [92, 255, 0],\n",
    "            [0, 224, 255],\n",
    "            [112, 224, 255],\n",
    "            [70, 184, 160],\n",
    "            [163, 0, 255],\n",
    "            [153, 0, 255],\n",
    "            [71, 255, 0],\n",
    "            [255, 0, 163],\n",
    "            [255, 204, 0],\n",
    "            [255, 0, 143],\n",
    "            [0, 255, 235],\n",
    "            [133, 255, 0],\n",
    "            [255, 0, 235],\n",
    "            [245, 0, 255],\n",
    "            [255, 0, 122],\n",
    "            [255, 245, 0],\n",
    "            [10, 190, 212],\n",
    "            [214, 255, 0],\n",
    "            [0, 204, 255],\n",
    "            [20, 0, 255],\n",
    "            [255, 255, 0],\n",
    "            [0, 153, 255],\n",
    "            [0, 41, 255],\n",
    "            [0, 255, 204],\n",
    "            [41, 0, 255],\n",
    "            [41, 255, 0],\n",
    "            [173, 0, 255],\n",
    "            [0, 245, 255],\n",
    "            [71, 0, 255],\n",
    "            [122, 0, 255],\n",
    "            [0, 255, 184],\n",
    "            [0, 92, 255],\n",
    "            [184, 255, 0],\n",
    "            [0, 133, 255],\n",
    "            [255, 214, 0],\n",
    "            [25, 194, 194],\n",
    "            [102, 255, 0],\n",
    "            [92, 0, 255],\n",
    "        ]\n",
    "    )"
   ],
   "id": "3a56830ac17a3095"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)\n",
    "palette = np.array(ade_palette())\n",
    "\n",
    "for label, color in enumerate(palette):\n",
    "    color_seg[pred_seg == label, :] = color\n",
    "color_seg = color_seg[..., ::-1]  # convert to BGR\n",
    "\n",
    "img = np.array(image) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map\n",
    "img = img.astype(np.uint8)\n",
    "\n",
    "plt.figure(figsize=(15, 10))\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ],
   "id": "622311504d31710d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The results are definitely not as expected and, as mentioned above, this example is not meant to provide a state-of-the-art model. It exists to familiarize you with the end-to-end workflow. \n",
    "\n",
    "On the other hand, if you perform full fine-tuning on the same setup (same model variant, same dataset, same training schedule, etc.), the results would not have been any different. This is a crucial aspect of parameter-efficient fine-tuning -- to be able to match up to the results of the full fine-tuning but with a fraction of total trainable parameters. \n",
    "\n",
    "Here are some things that you can try to get better results:\n",
    "\n",
    "* Increase the number of training samples. \n",
    "* Try a larger SegFormer model variant (browse the available model variants [here](https://huggingface.co/models?search=segformer)). \n",
    "* Try different values for the arguments available in `LoraConfig`. \n",
    "* Tune the learning rate and batch size. "
   ],
   "id": "c652b6d193715e0a"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
