{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from accelerate.logging import get_logger\n",
    "from diffusers import StableDiffusionPipeline\n",
    "from diffusers.utils import check_min_version\n",
    "\n",
    "from peft import PeftModel\n",
    "\n",
    "# Will error if the minimum required version of diffusers is not installed. Remove at your own risk.\n",
    "check_min_version(\"0.10.0.dev0\")\n",
    "\n",
    "logger = get_logger(__name__)\n",
    "\n",
    "# Base Stable Diffusion checkpoint handed to the pipeline loader.\n",
    "MODEL_NAME = \"stabilityai/stable-diffusion-2-1\"\n",
    "# MODEL_NAME=\"runwayml/stable-diffusion-v1-5\"\n",
    "\n",
    "# Identifiers of the run whose adapter is loaded. Within this notebook the three numeric\n",
    "# values only feed RUN_NAME; presumably they mirror the BOFT settings used at training\n",
    "# time -- TODO confirm against the training run.\n",
    "PEFT_TYPE = \"boft\"\n",
    "BLOCK_NUM = 8\n",
    "BLOCK_SIZE = 0\n",
    "N_BUTTERFLY_FACTOR = 1\n",
    "SELECTED_SUBJECT = \"backpack\"\n",
    "# Checkpoint epoch to load; it becomes a component of the checkpoint path.\n",
    "EPOCH_IDX = 200\n",
    "\n",
    "PROJECT_NAME = f\"dreambooth_{PEFT_TYPE}\"\n",
    "# RUN_NAME is used both as the adapter name and as the checkpoint sub-directory name.\n",
    "RUN_NAME = f\"{SELECTED_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}\"\n",
    "# Root directory under which the adapter checkpoints are looked up.\n",
    "OUTPUT_DIR = f\"./data/output/{PEFT_TYPE}\""
   ],
   "id": "322b25e8abcdcb0b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_boft_sd_pipeline(\n",
    "        ckpt_dir, base_model_name_or_path=None, epoch=None, dtype=torch.float32, device=\"cuda\", adapter_name=\"default\"\n",
    "):\n",
    "    \"\"\"Build a Stable Diffusion pipeline and load a trained adapter checkpoint into it.\n",
    "\n",
    "    Args:\n",
    "        ckpt_dir: Root directory containing the `unet/<epoch>/<adapter_name>` checkpoint and,\n",
    "            optionally, a `text_encoder/<epoch>/<adapter_name>` checkpoint.\n",
    "        base_model_name_or_path: Base model passed to `StableDiffusionPipeline.from_pretrained`.\n",
    "        epoch: Checkpoint epoch to load; it is interpolated into the checkpoint path.\n",
    "        dtype: Torch dtype used to load the base model; for float16/bfloat16 the UNet and\n",
    "            text encoder are cast to it again after the adapter has been loaded.\n",
    "        device: Device the pipeline is moved to.\n",
    "        adapter_name: Name the adapter is registered under; also its checkpoint sub-directory.\n",
    "\n",
    "    Returns:\n",
    "        The pipeline, with the adapter loaded, on `device`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `base_model_name_or_path` or `epoch` is not provided.\n",
    "    \"\"\"\n",
    "    if base_model_name_or_path is None:\n",
    "        raise ValueError(\"Please specify the base model name or path\")\n",
    "    # The epoch becomes part of the checkpoint path, so a missing value has to fail here\n",
    "    # instead of silently producing a directory name that cannot exist.\n",
    "    if epoch is None:\n",
    "        raise ValueError(\"Please specify the checkpoint epoch to load\")\n",
    "\n",
    "    pipe = StableDiffusionPipeline.from_pretrained(\n",
    "        base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False\n",
    "    ).to(device)\n",
    "\n",
    "    load_adapter(pipe, ckpt_dir, epoch, adapter_name)\n",
    "\n",
    "    # Cast to the dtype that was actually requested: `.half()` always yields float16 and\n",
    "    # would leave a bfloat16 pipeline with float16 UNet/text-encoder weights.\n",
    "    # Presumably needed because freshly loaded adapter weights are not in `dtype` -- TODO confirm.\n",
    "    if dtype in (torch.float16, torch.bfloat16):\n",
    "        pipe.unet.to(dtype=dtype)\n",
    "        pipe.text_encoder.to(dtype=dtype)\n",
    "\n",
    "    # Move again so modules replaced by `load_adapter` end up on `device` as well.\n",
    "    pipe.to(device)\n",
    "    return pipe\n",
    "\n",
    "\n",
    "def load_adapter(pipe, ckpt_dir, epoch, adapter_name=\"default\"):\n",
    "    \"\"\"Load the adapter checkpoint for `epoch` into the pipeline's UNet and, if present, text encoder.\n",
    "\n",
    "    A component that is already a `PeftModel` receives the checkpoint through its own\n",
    "    `load_adapter`; otherwise the component is replaced by `PeftModel.from_pretrained`\n",
    "    wrapped around it. The text encoder is only touched when its checkpoint directory exists.\n",
    "    \"\"\"\n",
    "    unet_path = os.path.join(ckpt_dir, f\"unet/{epoch}\", adapter_name)\n",
    "    encoder_path = os.path.join(ckpt_dir, f\"text_encoder/{epoch}\", adapter_name)\n",
    "\n",
    "    if not isinstance(pipe.unet, PeftModel):\n",
    "        pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_path, adapter_name=adapter_name)\n",
    "    else:\n",
    "        pipe.unet.load_adapter(unet_path, adapter_name=adapter_name)\n",
    "\n",
    "    # Text-encoder checkpoint directory absent: leave the text encoder untouched.\n",
    "    if not os.path.exists(encoder_path):\n",
    "        return\n",
    "\n",
    "    if not isinstance(pipe.text_encoder, PeftModel):\n",
    "        pipe.text_encoder = PeftModel.from_pretrained(\n",
    "            pipe.text_encoder, encoder_path, adapter_name=adapter_name\n",
    "        )\n",
    "    else:\n",
    "        pipe.text_encoder.load_adapter(encoder_path, adapter_name=adapter_name)\n",
    "\n",
    "\n",
    "def set_adapter(pipe, adapter_name):\n",
    "    \"\"\"Call `set_adapter(adapter_name)` on the UNet and on the text encoder when it is a `PeftModel`.\"\"\"\n",
    "    peft_components = [pipe.unet]\n",
    "    if isinstance(pipe.text_encoder, PeftModel):\n",
    "        peft_components.append(pipe.text_encoder)\n",
    "    for component in peft_components:\n",
    "        component.set_adapter(adapter_name)"
   ],
   "id": "e2c68df3192805"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Text prompts passed to the pipeline call; `sks` is presumably the identifier token used\n",
    "# during DreamBooth training -- TODO confirm against the training run.\n",
    "prompt = \"a photo of sks backpack on a wooden floor\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\""
   ],
   "id": "9db03e7d4595e64f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Build the pipeline from MODEL_NAME and load the adapter named RUN_NAME from OUTPUT_DIR at epoch EPOCH_IDX.\n",
    "pipe = get_boft_sd_pipeline(OUTPUT_DIR, MODEL_NAME, EPOCH_IDX, adapter_name=RUN_NAME)"
   ],
   "id": "43cbc01ee4addc77"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Run the pipeline with 50 inference steps and guidance scale 7 and keep the first returned image;\n",
    "# the bare `image` expression on the last line displays it as the cell output.\n",
    "image = pipe(prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt).images[0]\n",
    "image"
   ],
   "id": "3df9767311c74648"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load a second adapter into the existing pipeline and make it the selected one.\n",
    "# WARNING: requires training DreamBooth with `boft_bias=None`\n",
    "# NOTE: the assignments below overwrite SELECTED_SUBJECT, EPOCH_IDX and RUN_NAME from the\n",
    "# configuration cell, so re-running earlier cells afterwards picks up these values.\n",
    "\n",
    "SELECTED_SUBJECT = \"dog\"\n",
    "EPOCH_IDX = 200\n",
    "RUN_NAME = f\"{SELECTED_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}\"\n",
    "\n",
    "load_adapter(pipe, OUTPUT_DIR, epoch=EPOCH_IDX, adapter_name=RUN_NAME)\n",
    "set_adapter(pipe, adapter_name=RUN_NAME)"
   ],
   "id": "2e4d043a6a1455fb"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Rebind the prompts for the dog subject and run the pipeline with whichever adapter is\n",
    "# currently set on it; the bare `image` expression displays the result.\n",
    "prompt = \"a photo of sks dog running on the beach\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\"\n",
    "image = pipe(prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt).images[0]\n",
    "image"
   ],
   "id": "c633a47e9ebba681"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "6517ab1020cfe0df"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
