{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Using PEFT with timm",
   "id": "84ea6c8a49b7fefe"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "`peft` lets us fine-tune any model with LoRA, as long as the targeted layer types are supported. Since `Conv2d` is one of the supported layer types, it makes sense to try it on image models.\n",
    "\n",
    "In this short notebook, we will demonstrate this with an image classification task using [`timm`](https://huggingface.co/docs/timm/index)."
   ],
   "id": "51a33e5c06b3ab54"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Imports",
   "id": "9c113e643e84ae2d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Make sure that you have the latest version of `peft` installed. To ensure that, run this in your Python environment:\n",
    "    \n",
    "    python -m pip install --upgrade peft\n",
    "    \n",
    "Also, ensure that `timm` is installed:\n",
    "\n",
    "    python -m pip install --upgrade timm"
   ],
   "id": "5b58eaed66367676"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import timm\n",
    "import torch\n",
    "from PIL import Image\n",
    "from timm.data import resolve_data_config\n",
    "from timm.data.transforms_factory import create_transform"
   ],
   "id": "326890a5b79a083a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import peft\n",
    "from datasets import load_dataset"
   ],
   "id": "35ca66c957d84a40"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "torch.manual_seed(0)",
   "id": "3a60319cc96afca6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Loading the pre-trained base model",
   "id": "3fa6a1b22f97246c"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We use a small pretrained `timm` model, `PoolFormer`. Find more info on its [model card](https://huggingface.co/timm/poolformer_m36.sail_in1k).",
   "id": "5d2e02cef0b676db"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model_id_timm = \"timm/poolformer_m36.sail_in1k\"",
   "id": "941b0cdb1d224c8e"
  },
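  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We tell `timm` that we are dealing with 3 classes, so that the classification layer is created with the correct output size. Since this differs from the 1000 ImageNet classes of the pretrained checkpoint, the classification layer is newly initialized and has to be trained.",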
   "id": "170aa2e27323e5cb"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)",
   "id": "3b8e4fe2100efca7"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "These are the transformation steps needed to preprocess the images for this model.",
   "id": "452b2dc5b5bde6fc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))",
   "id": "e0960957b3c0eeac"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Data",
   "id": "f6b4343e0b9c4623"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "For this exercise, we use the \"beans\" dataset. More details on the dataset can be found on [its datasets page](https://huggingface.co/datasets/beans). For our purposes, what's important is that we have image inputs and the target we're trying to predict is one of three classes for each image.",
   "id": "ae78fd684af485f8"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "ds = load_dataset(\"beans\")",
   "id": "590eb67a759942fa"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "ds_train = ds[\"train\"]\n",
    "ds_valid = ds[\"validation\"]"
   ],
   "id": "6f17825a0a5e40c2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "ds_train[0][\"image\"]",
   "id": "9087ebcaea240a33"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We define a small processing function which is responsible for loading and transforming the images, as well as extracting the labels.",
   "id": "2b37ce25b3c881ef"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def process(batch):\n",
    "    x = torch.cat([transform(img).unsqueeze(0) for img in batch[\"image\"]])\n",
    "    y = torch.tensor(batch[\"labels\"])\n",
    "    return {\"x\": x, \"y\": y}"
   ],
   "id": "d32826e4dd41f3d1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "ds_train.set_transform(process)\n",
    "ds_valid.set_transform(process)"
   ],
   "id": "592935fbddd4fa0a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "train_loader = torch.utils.data.DataLoader(ds_train, batch_size=32)\n",
    "valid_loader = torch.utils.data.DataLoader(ds_valid, batch_size=32)"
   ],
   "id": "abcc19dc95a9f989"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Training",
   "id": "61a1978432e3e340"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "This function runs a plain training loop: one pass over the training data per epoch, followed by a validation pass that reports loss and accuracy.",
   "id": "5526ae1bc50a7d14"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def train(model, optimizer, criterion, train_dataloader, valid_dataloader, epochs):\n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        train_loss = 0\n",
    "        for batch in train_dataloader:\n",
    "            xb, yb = batch[\"x\"], batch[\"y\"]\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            outputs = model(xb)\n",
    "            # pass raw logits: CrossEntropyLoss applies log-softmax internally\n",
    "            loss = criterion(outputs, yb)\n",
    "            train_loss += loss.detach().float()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        model.eval()\n",
    "        valid_loss = 0\n",
    "        correct = 0\n",
    "        n_total = 0\n",
    "        for batch in valid_dataloader:\n",
    "            xb, yb = batch[\"x\"], batch[\"y\"]\n",
    "            xb, yb = xb.to(device), yb.to(device)\n",
    "            with torch.no_grad():\n",
    "                outputs = model(xb)\n",
    "                loss = criterion(outputs, yb)\n",
    "            valid_loss += loss.detach().float()\n",
    "            correct += (outputs.argmax(-1) == yb).sum().item()\n",
    "            n_total += len(yb)\n",
    "\n",
    "        train_loss_total = (train_loss / len(train_dataloader)).item()\n",
    "        valid_loss_total = (valid_loss / len(valid_dataloader)).item()\n",
    "        valid_acc_total = correct / n_total\n",
    "        print(f\"{epoch=:<2}  {train_loss_total=:.4f}  {valid_loss_total=:.4f}  {valid_acc_total=:.4f}\")"
   ],
   "id": "4213a1d8040e2f92"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Selecting which layers to fine-tune with LoRA",
   "id": "61b9381247809b64"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Let's take a look at the layers of our model. We only print the first 30, since there are quite a few:",
   "id": "2468dcc108407cfe"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "[(n, type(m)) for n, m in model.named_modules()][:30]",
   "id": "ead3952296597fc8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Most of these layers are not good targets for LoRA, but we see a couple that should interest us. Their names are `'stages.0.blocks.0.mlp.fc1'`, etc. With a bit of regex, we can match them easily.\n",
    "\n",
    "Also, we should inspect the name of the classification layer, since we want to train that one too!"
   ],
   "id": "12178e55bfa7f4e1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "[(n, type(m)) for n, m in model.named_modules()][-5:]",
   "id": "686ef854e6502c1"
  },
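  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "One option would be to also match the classification layer in `target_modules`, which would apply LoRA to it:\n",
    "\n",
    "    config = peft.LoraConfig(\n",
    "        r=8,\n",
    "        target_modules=r\".*\\.mlp\\.fc\\d|head\\.fc\",\n",
    "    )\n",
    "\n",
    "In this notebook, we train the classification layer without LoRA instead."
   ],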
   "id": "5eef8f1ab5f1859f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Okay, this gives us all the information we need to fine-tune this model. The regex passed as `target_modules` matches the convolutional `mlp.fc` layers that should be targeted for LoRA. We also want to train the classification layer `'head.fc'` (without LoRA), so we add it to `modules_to_save`.",
   "id": "cf6f83f6f4f24747"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "config = peft.LoraConfig(r=8, target_modules=r\".*\\.mlp\\.fc\\d\", modules_to_save=[\"head.fc\"])",
   "id": "3ae6835e514e113e"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Finally, let's create the `peft` model, the optimizer and criterion, and we can get started. As shown below, less than 2% of the model's total parameters are updated thanks to `peft`.",
   "id": "39a886f0ac194242"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "peft_model = peft.get_peft_model(model, config).to(device)\n",
    "optimizer = torch.optim.Adam(peft_model.parameters(), lr=2e-4)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "peft_model.print_trainable_parameters()"
   ],
   "id": "65cc2e69571ffedb"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "%time train(peft_model, optimizer, criterion, train_loader, valid_dataloader=valid_loader, epochs=10)",
   "id": "fc253d1682b0ad3"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We get an accuracy of ~0.97, despite training only a small fraction of the parameters. That's a really nice result.",
   "id": "34968411f7300ac5"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Sharing the model through Hugging Face Hub",
   "id": "537f19281956e45"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Pushing the model to Hugging Face Hub",
   "id": "f410013da618b87f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "If we want to share the fine-tuned weights with the world, we can upload them to Hugging Face Hub like this:",
   "id": "3fcfc6e48f46460b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "user = \"BenjaminB\"  # put your user name here\n",
    "model_name = \"peft-lora-with-timm-model\"\n",
    "model_id = f\"{user}/{model_name}\""
   ],
   "id": "4b8a7d7b0726c8fa"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "peft_model.push_to_hub(model_id);",
   "id": "c4f423b62d4c9b3f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "As we can see, the adapter is only 4.3 MB, whereas the original model was 225 MB. That's a substantial saving.",
   "id": "ff45ed62275f058a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Loading the model from HF Hub",
   "id": "afbd9a061ad54ae6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Now, it only takes one step to load the model from HF Hub. To do this, we can use `PeftModel.from_pretrained`, passing our base model and the model ID:",
   "id": "2b21e62d84264579"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "base_model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)\n",
    "loaded = peft.PeftModel.from_pretrained(base_model, model_id)"
   ],
   "id": "3d65435759cb3af0"
  },
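  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "x = ds_train[:1][\"x\"]\n",
    "peft_model.eval()\n",
    "loaded.eval()\n",
    "# compare both models in eval mode, without tracking gradients\n",
    "with torch.no_grad():\n",
    "    y_peft = peft_model(x.to(device))\n",
    "    y_loaded = loaded(x)\n",
    "torch.allclose(y_peft.cpu(), y_loaded)"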
   ],
   "id": "cef90bc852aa7522"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Clean up",
   "id": "48d2a0062fc30bd4"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Finally, as a cleanup step, you may want to delete the repo.",
   "id": "bf73dbe41cbf2eb3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "from huggingface_hub import delete_repo",
   "id": "a9ae223fb090a126"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "delete_repo(model_id)",
   "id": "1cdfbe7723eae30a"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
