{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "This notebook shows how to use the adapter merging methods from `peft` and apply them to image generation models using `diffusers`."
   ],
   "id": "9cf0c677f0c5c3f1"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Turn `diffusers` LoRA checkpoints into `PeftModel`",
   "id": "a72f3e2ba741d5c1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Install the runtime dependencies. `%pip` (rather than `!pip`) installs into the\n",
    "# environment of the running kernel, so the packages are importable from this notebook.\n",
    "# `peft` is installed from its GitHub repository, the other packages are upgraded from PyPI.\n",
    "%pip install diffusers accelerate transformers -U -q\n",
    "%pip install git+https://github.com/huggingface/peft -q"
   ],
   "id": "f02e88613cd45438"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import os\n",
    "\n",
    "# Hugging Face token that is passed to `push_to_hub` further down.\n",
    "# On Colab it is read from the notebook secrets. Anywhere else `google.colab` cannot be\n",
    "# imported, so fall back to the `HF_TOKEN` environment variable (None when it is unset).\n",
    "try:\n",
    "    from google.colab import userdata\n",
    "except ImportError:\n",
    "    TOKEN = os.environ.get(\"HF_TOKEN\")\n",
    "else:\n",
    "    TOKEN = userdata.get(\"HF_TOKEN\")"
   ],
   "id": "c91de7eb131953d4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from diffusers import UNet2DConditionModel\n",
    "\n",
    "# Checkpoint used throughout the notebook; only its `unet` subfolder is loaded here.\n",
    "model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
    "\n",
    "# Load the fp16 safetensors variant of the UNet, then move it to the GPU.\n",
    "unet = UNet2DConditionModel.from_pretrained(\n",
    "    model_id,\n",
    "    subfolder=\"unet\",\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "unet = unet.to(\"cuda\")"
   ],
   "id": "c0dab7637a69f7bc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Keep an independent copy of the freshly loaded UNet, taken before any LoRA weights\n",
    "# are loaded through the pipeline, so that we can populate it later.\n",
    "import copy\n",
    "\n",
    "sdxl_unet = copy.deepcopy(unet)"
   ],
   "id": "73ce8ae4901597ca"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "# Build the full pipeline around the UNet that is already loaded (passed as `unet=`).\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    model_id,\n",
    "    unet=unet,\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")"
   ],
   "id": "c3ef940af5ea1e06"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load the toy-face LoRA into the pipeline under the adapter name `toy`.\n",
    "# NOTE(review): the original note here read 'Only UNet' - presumably this checkpoint\n",
    "# only carries UNet weights; confirm.\n",
    "pipe.load_lora_weights(\n",
    "    \"CiroN2022/toy-face\",\n",
    "    weight_name=\"toy_face_sdxl.safetensors\",\n",
    "    adapter_name=\"toy\",\n",
    ")"
   ],
   "id": "6312ee0030a52dfc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "# Wrap the UNet copy in a PEFT model, reusing the `toy` LoRA config found on the\n",
    "# pipeline's UNet and registering it under the same adapter name.\n",
    "toy_peft_model = get_peft_model(sdxl_unet, pipe.unet.peft_config[\"toy\"], adapter_name=\"toy\")"
   ],
   "id": "b670cb89244e4e8a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Re-key the pipeline UNet's state dict with the `base_model.model.` prefix before\n",
    "# loading it into the PEFT-wrapped model; `strict=True` asks for an exact key match.\n",
    "original_state_dict = {\n",
    "    f\"base_model.model.{name}\": tensor\n",
    "    for name, tensor in pipe.unet.state_dict().items()\n",
    "}\n",
    "toy_peft_model.load_state_dict(original_state_dict, strict=True)\n",
    "\n",
    "# Upload the populated PEFT model to the Hub.\n",
    "toy_peft_model.push_to_hub(\"toy_peft_model-new\", token=TOKEN)"
   ],
   "id": "97e297664e4b0a42"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Remove the `toy` adapter from both the pipeline and the UNet copy.\n",
    "# NOTE(review): presumably so that both objects can be reused for the next adapter - confirm.\n",
    "pipe.delete_adapters(\"toy\")\n",
    "sdxl_unet.delete_adapters(\"toy\")"
   ],
   "id": "de52fd7f8037114f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load the pixel-art LoRA under the adapter name `pixel`, then select it on the pipeline.\n",
    "pipe.load_lora_weights(\n",
    "    \"nerijs/pixel-art-xl\",\n",
    "    weight_name=\"pixel-art-xl.safetensors\",\n",
    "    adapter_name=\"pixel\",\n",
    ")\n",
    "pipe.set_adapters(adapter_names=\"pixel\")"
   ],
   "id": "e24265425d8f33bd"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Wrap the UNet copy with the `pixel` LoRA config found on the pipeline's UNet.\n",
    "pixel_peft_model = get_peft_model(sdxl_unet, pipe.unet.peft_config[\"pixel\"], adapter_name=\"pixel\")\n",
    "\n",
    "# Re-key the pipeline UNet's state dict with the `base_model.model.` prefix, load it\n",
    "# into the PEFT-wrapped model (exact key match requested) and upload the result.\n",
    "original_state_dict = {\n",
    "    f\"base_model.model.{name}\": tensor\n",
    "    for name, tensor in pipe.unet.state_dict().items()\n",
    "}\n",
    "pixel_peft_model.load_state_dict(original_state_dict, strict=True)\n",
    "pixel_peft_model.push_to_hub(\"pixel_peft_model-new\", token=TOKEN)"
   ],
   "id": "8c4d5cdf9e478b84"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import gc\n",
    "\n",
    "# Release everything created for the conversion above.\n",
    "# `unet` is the same object as `pipe.unet`, and `original_state_dict` still holds the\n",
    "# tensors of that UNet, so both have to go as well or the weights stay on the GPU.\n",
    "del pipe, unet, sdxl_unet, toy_peft_model, pixel_peft_model, original_state_dict\n",
    "\n",
    "# Collect any leftover reference cycles, then return the cached blocks to the GPU.\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ],
   "id": "d6fc8f5e5546cfe8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Weighted adapter inference",
   "id": "c358a7c9e7740373"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import PeftModel\n",
    "\n",
    "# Start from a freshly loaded fp16 UNet on the GPU.\n",
    "base_unet = UNet2DConditionModel.from_pretrained(\n",
    "    model_id,\n",
    "    subfolder=\"unet\",\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "base_unet = base_unet.to(\"cuda\")\n",
    "\n",
    "# Attach the two adapters from the Hub under the names `toy` and `pixel`.\n",
    "toy_id = \"sayakpaul/toy_peft_model-new\"\n",
    "model = PeftModel.from_pretrained(\n",
    "    base_unet,\n",
    "    toy_id,\n",
    "    use_safetensors=True,\n",
    "    subfolder=\"toy\",\n",
    "    adapter_name=\"toy\",\n",
    ")\n",
    "model.load_adapter(\n",
    "    \"sayakpaul/pixel_peft_model-new\",\n",
    "    use_safetensors=True,\n",
    "    subfolder=\"pixel\",\n",
    "    adapter_name=\"pixel\",\n",
    ")\n",
    "\n",
    "# Create a new adapter `toy-pixel` as a 0.7 / 0.3 `linear` combination and select it.\n",
    "# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter\n",
    "model.add_weighted_adapter(\n",
    "    adapters=[\"toy\", \"pixel\"], weights=[0.7, 0.3], combination_type=\"linear\", adapter_name=\"toy-pixel\"\n",
    ")\n",
    "model.set_adapters(\"toy-pixel\")"
   ],
   "id": "e5c31f7db5f497a2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Show the class of the object wrapped at `model.base_model.model`.\n",
    "type(model.base_model.model)"
   ],
   "id": "cf703c55bbcc6ead"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Cast the adapter-carrying model to fp16 on the GPU before handing it to the pipeline.\n",
    "model = model.to(device=\"cuda\", dtype=torch.float16)\n",
    "\n",
    "# The pipeline uses `model` as its UNet.\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    model_id,\n",
    "    unet=model,\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")\n",
    "\n",
    "# Generate one image with a fixed seed and display it as the cell output.\n",
    "prompt = \"toy_face of a hacker with a hoodie, pixel art\"\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=30,\n",
    "    generator=torch.manual_seed(0),\n",
    ").images[0]\n",
    "image"
   ],
   "id": "f21f5f61d7203acb"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import gc\n",
    "\n",
    "# `del pipe` alone does not release the UNet: it is still referenced by `model` and\n",
    "# `base_unet`. Both names are re-created by the following cell, so drop them here too\n",
    "# and return the freed memory to the GPU before another UNet is loaded.\n",
    "del pipe, model, base_unet\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ],
   "id": "b162e39d2d39a200"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Reload a clean fp16 UNet on the GPU for the second combination type.\n",
    "base_unet = UNet2DConditionModel.from_pretrained(\n",
    "    model_id,\n",
    "    subfolder=\"unet\",\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "base_unet = base_unet.to(\"cuda\")\n",
    "\n",
    "# Attach the two adapters from the Hub under the names `toy` and `pixel`.\n",
    "toy_id = \"sayakpaul/toy_peft_model-new\"\n",
    "model = PeftModel.from_pretrained(\n",
    "    base_unet,\n",
    "    toy_id,\n",
    "    use_safetensors=True,\n",
    "    subfolder=\"toy\",\n",
    "    adapter_name=\"toy\",\n",
    ")\n",
    "model.load_adapter(\n",
    "    \"sayakpaul/pixel_peft_model-new\",\n",
    "    use_safetensors=True,\n",
    "    subfolder=\"pixel\",\n",
    "    adapter_name=\"pixel\",\n",
    ")\n",
    "\n",
    "# Create a new adapter `toy-pixel` with equal weights using the `cat` combination type\n",
    "# and select it.\n",
    "# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter\n",
    "model.add_weighted_adapter(\n",
    "    adapters=[\"toy\", \"pixel\"], weights=[0.5, 0.5], combination_type=\"cat\", adapter_name=\"toy-pixel\"\n",
    ")\n",
    "model.set_adapters(\"toy-pixel\")"
   ],
   "id": "bf16928c4d354cd3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Cast the adapter-carrying model to fp16 on the GPU before handing it to the pipeline.\n",
    "model = model.to(device=\"cuda\", dtype=torch.float16)\n",
    "\n",
    "# The pipeline uses `model` as its UNet.\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    model_id,\n",
    "    unet=model,\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")\n",
    "\n",
    "# Same prompt, step count and seed as the previous generation, for comparison.\n",
    "prompt = \"toy_face of a hacker with a hoodie, pixel art\"\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=30,\n",
    "    generator=torch.manual_seed(0),\n",
    ").images[0]\n",
    "image"
   ],
   "id": "8b1333b957280059"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "del pipe\n",
    "\n",
    "# Baseline: the stock pipeline, loaded without passing a custom UNet.\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")\n",
    "\n",
    "# Same prompt, step count and seed as the adapter runs above.\n",
    "prompt = \"toy_face of a hacker with a hoodie, pixel art\"\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=30,\n",
    "    generator=torch.manual_seed(0),\n",
    ").images[0]\n",
    "image"
   ],
   "id": "d172e138745951e1"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
