{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Initializing weights with LoftQ by replacing LoRA weights in-place",
   "id": "b023baf3e2c9e8e2"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "This notebook shows how to apply [LoftQ](https://arxiv.org/abs/2310.08659) initialization to a QLoRA model.\n",
    "\n",
    "In short, the idea behind LoftQ is the following. When we use QLoRA, i.e. we quantize the base model with bitsandbytes to save memory and then train LoRA weights on top of this base model, we expect a certain performance gap. This is partly because quantization is only an approximation of the \"real\" weights and thus introduces a quantization error. By default, LoRA weights are initialized such that they are a no-op at the start of training. However, we can instead initialize them so that they minimize the quantization error. This is the idea behind LoftQ.\n",
    "\n",
    "Note that this only influences the initialization of the model. Everything that follows stays the same as always."
   ],
   "id": "be2185dd3ba415ad"
  },
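  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "As a rough sketch of the objective, written in the notation of the LoftQ paper rather than with any PEFT identifiers: let $W$ be an original weight matrix, $Q$ its quantized counterpart, and $A B^\\top$ a low-rank LoRA update of rank $r$. LoftQ looks for an initialization that approximately solves\n",
    "\n",
    "$$\\min_{Q, A, B} \\lVert W - Q - A B^\\top \\rVert_F$$\n",
    "\n",
    "The default LoRA initialization corresponds to $A B^\\top = 0$, which leaves the full quantization error $\\lVert W - Q \\rVert_F$ in place. For a fixed $Q$, the best rank-$r$ choice of $A B^\\top$ is the truncated SVD of the residual $W - Q$ (Eckart-Young theorem). How closely `replace_lora_weights_loftq` follows this recipe is an implementation detail that this sketch does not cover."
   ],
   "id": "loftq-objective-sketch"
  },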
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Imports",
   "id": "1eb5a206194d876e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import os\n",
    "import torch"
   ],
   "id": "16b5f40cc68473e5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig",
   "id": "23840953c98dd1a3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "from peft import get_peft_model, LoraConfig, replace_lora_weights_loftq",
   "id": "80e1e1f00054f7d2"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Functions",
   "id": "46b892d14f065d0e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_mae(x, y):\n",
    "    return (x - y).abs().mean()\n",
    "\n",
    "\n",
    "def get_mse(x, y):\n",
    "    return torch.pow(x - y, 2).mean()\n",
    "\n",
    "\n",
    "def error_report(x, y):\n",
    "    mae = get_mae(x, y)\n",
    "    mse = get_mse(x, y)\n",
    "    print(\n",
    "        f\"Mean absolute error: {mae:>8.5f}\\n\"\n",
    "        f\"Mean squared error:  {mse:>8.5f}\"\n",
    "    )"
   ],
   "id": "bd2c1c502f537494"
  },
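  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "For reference, these helpers compute the usual element-wise error metrics. With $x$ and $y$ two tensors of the same shape and $n$ their total number of elements (symbols chosen here for illustration):\n",
    "\n",
    "$$\\mathrm{MAE}(x, y) = \\frac{1}{n} \\sum_{i=1}^{n} \\lvert x_i - y_i \\rvert, \\qquad \\mathrm{MSE}(x, y) = \\frac{1}{n} \\sum_{i=1}^{n} (x_i - y_i)^2$$\n",
    "\n",
    "Both means run over every element of the tensors. When applied to logits of a padded batch, padding positions are therefore included in the average, since no mask is applied."
   ],
   "id": "error-metrics-sketch"
  },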
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Base model",
   "id": "c9b6313cc1437aff"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "First, let's load a base model and calculate some logits. These logits are the baseline, i.e. we try to match their values as best as possible. We only need these logits for demonstration purposes. In practice, it is not necessary to load the non-quantized weights to apply LoftQ initialization.\n",
    "\n",
    "**Note**: We have to choose a model with a `model.safetensors` file. PyTorch checkpoints (pickle) cannot be loaded lazily, so we have to use [safetensors](https://huggingface.co/docs/safetensors/index). If no safetensors file exists for your model, save the pretrained model as a safetensors file using `save_pretrained` and pass the model path to `replace_lora_weights_loftq`."
   ],
   "id": "74423026c3d31082"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model_id = \"bigscience/bloomz-560m\"",
   "id": "19de5c622306af95"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "tokenizer = AutoTokenizer.from_pretrained(model_id)",
   "id": "8ece99e932a2fe8a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model = AutoModelForCausalLM.from_pretrained(model_id)",
   "id": "25d30281ab51be62"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "s = \"\"\"Beautiful is better than ugly.\n",
    "Explicit is better than implicit.\n",
    "Simple is better than complex.\n",
    "Complex is better than complicated.\n",
    "Flat is better than nested.\n",
    "Sparse is better than dense.\n",
    "Readability counts.\n",
    "Special cases aren't special enough to break the rules.\n",
    "Although practicality beats purity.\n",
    "Errors should never pass silently.\n",
    "Unless explicitly silenced.\n",
    "In the face of ambiguity, refuse the temptation to guess.\n",
    "There should be one-- and preferably only one --obvious way to do it.\n",
    "Although that way may not be obvious at first unless you're Dutch.\n",
    "Now is better than never.\n",
    "Although never is often better than *right* now.\n",
    "If the implementation is hard to explain, it's a bad idea.\n",
    "If the implementation is easy to explain, it may be a good idea.\n",
    "Namespaces are one honking great idea -- let's do more of those!\"\"\""
   ],
   "id": "3b9175acf643e358"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "inputs = tokenizer(s.splitlines(), return_tensors=\"pt\", padding=True)",
   "id": "f1df70d0d6c79bd3"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Our baseline logits:",
   "id": "9a9bada3b52f77c7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "logits_base = model(**inputs).logits",
   "id": "b97d8ff35cc31751"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Normal LoRA model",
   "id": "2a6db91c429e2f5d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Now we load the model quantized with bitsandbytes. For now, only 4-bit quantization is supported.",
   "id": "ec6d0639f8abf665"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=torch.float16,\n",
    ")"
   ],
   "id": "7f93e56cfadeff5e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)",
   "id": "c5c2d58812170874"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Next we create a LoRA model using PEFT and compute the logits of that model.",
   "id": "e9dca7706cfd86d3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "lora_config = LoraConfig(task_type=\"CAUSAL_LM\", target_modules=\"all-linear\")",
   "id": "53bee085237c27d9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "peft_model = get_peft_model(model, lora_config)",
   "id": "67f85371916b8f42"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "logits_lora = peft_model(**inputs).logits",
   "id": "dfb3c0e2be8f2162"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Let's check the influence of the quantization error on our logits:",
   "id": "68a6825d65c83e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "error_report(logits_base, logits_lora)",
   "id": "8f6159bdf748c2d2"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## LoftQ",
   "id": "d57d1704358c478b"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Next, let's use LoftQ initialization and see if it helps reduce the error.",
   "id": "88640ad8338464ea"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "replace_lora_weights_loftq(peft_model)",
   "id": "ce81090d96a2ec50"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "logits_loftq = peft_model(**inputs).logits",
   "id": "dd6234b3bf7aac4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "error_report(logits_base, logits_loftq)",
   "id": "e07aad184d1c0ed1"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We can see that LoftQ initialization helped a little bit, but the difference is not huge.",
   "id": "69e593a5e0f4823f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## LoftQ with callback",
   "id": "d98e59b274b855b6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "To help with this, let's write a small callback function and pass it to `replace_lora_weights_loftq`. Each time a module's LoRA weights are replaced with LoftQ-initialized weights, the callback tests whether the quantization error is actually reduced. If it is not, the replacement is rolled back. This way, we keep only those replacements that improve the results.",
   "id": "4ed12a64c849c632"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Since PEFT has modified the base model, we should reload it\n",
    "model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)"
   ],
   "id": "fe808dbb34774a78"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "peft_model = get_peft_model(model, lora_config)",
   "id": "f48265713e527715"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "current_mse = float(\"inf\")",
   "id": "ceac6bfc3f0cf749"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def my_callback(model, module_name):\n",
    "    \"\"\"Callable to replace weights with LoftQ if the MSE is lower than the current best one.\"\"\"\n",
    "    global current_mse\n",
    "\n",
    "    logits = model(**inputs).logits\n",
    "    mse = get_mse(logits_base, logits)\n",
    "    if mse < current_mse:\n",
    "        current_mse = mse\n",
    "        print(f\"MSE improved for module {module_name}\")\n",
    "        return True\n",
    "    print(f\"MSE did not improve for module {module_name}\")\n",
    "    return False"
   ],
   "id": "e2ff74a20110a0b3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "replace_lora_weights_loftq(peft_model, callback=my_callback)",
   "id": "7312a04cc901e05"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "logits_loftq_callback = peft_model(**inputs).logits",
   "id": "1436e78199ae08bf"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "error_report(logits_base, logits_loftq_callback)",
   "id": "db4964bc851c1767"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We can see that applying LoftQ with the help of the callback reduced the error quite significantly.",
   "id": "5b3c6c8b6c54b82e"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Applying LoftQ multiple times",
   "id": "4d43a6a461ca6d62"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "It is possible to run `replace_lora_weights_loftq` multiple times on the same model when using the callback.",
   "id": "84a0c1eca3bf3cb7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "replace_lora_weights_loftq(peft_model, callback=my_callback)",
   "id": "646a335f1ceeea58"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "logits_loftq_callback_twice = peft_model(**inputs).logits",
   "id": "43828fc261e12ba5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "error_report(logits_base, logits_loftq_callback_twice)",
   "id": "659fecac86f85b79"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "There are further gains, but they are not very big.",
   "id": "303f3b8c4ce8d669"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
