{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import os\n",
    "from PIL import Image\n",
    "\n",
    "import torch\n",
    "from accelerate.logging import get_logger\n",
    "from diffusers import StableDiffusionPipeline\n",
    "from diffusers.utils import check_min_version\n",
    "\n",
    "from peft import PeftModel\n",
    "\n",
    "# Will error if the minimum required version of diffusers is not installed. Remove at your own risk.\n",
    "check_min_version(\"0.10.0.dev0\")\n",
    "\n",
    "logger = get_logger(__name__)\n",
    "\n",
    "# Base model identifier handed to StableDiffusionPipeline.from_pretrained.\n",
    "MODEL_NAME = \"stabilityai/stable-diffusion-2-1\"\n",
    "\n",
    "# Settings identifying the fine-tuning run whose adapter is loaded further down.\n",
    "PEFT_TYPE = \"hra\"\n",
    "HRA_R = 8  # presumably the HRA `r` hyperparameter used during training -- TODO confirm\n",
    "SELECTED_SUBJECT = \"backpack\"\n",
    "EPOCH_IDX = 1000  # selects the `unet/<EPOCH_IDX>` (and `text_encoder/<EPOCH_IDX>`) checkpoint sub-directory\n",
    "\n",
    "PROJECT_NAME = f\"dreambooth_{PEFT_TYPE}\"\n",
    "# Passed as `adapter_name`: the adapter's registered name and the leaf directory of its checkpoint.\n",
    "RUN_NAME = f\"{SELECTED_SUBJECT}_{PEFT_TYPE}_{HRA_R}\"\n",
    "# Passed as `ckpt_dir`: the root directory under which adapter checkpoints are looked up.\n",
    "OUTPUT_DIR = f\"./data/output/{PEFT_TYPE}\""
   ],
   "id": "fe02ca4d4dd6b867"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_hra_sd_pipeline(\n",
    "        ckpt_dir, base_model_name_or_path=None, epoch=None, dtype=torch.float32, device=\"cuda\", adapter_name=\"default\"\n",
    "):\n",
    "    \"\"\"Build a Stable Diffusion pipeline and attach the adapter saved for ``epoch``.\n",
    "\n",
    "    Args:\n",
    "        ckpt_dir: Root directory containing ``unet/<epoch>/<adapter_name>`` and, optionally,\n",
    "            ``text_encoder/<epoch>/<adapter_name>``.\n",
    "        base_model_name_or_path: Base model passed to ``StableDiffusionPipeline.from_pretrained``. Required.\n",
    "        epoch: Checkpoint index used as the sub-directory name. Required.\n",
    "        dtype: Torch dtype the pipeline weights are loaded in.\n",
    "        device: Device the pipeline is moved to.\n",
    "        adapter_name: Name the adapter is registered under; also the leaf directory it is read from.\n",
    "\n",
    "    Returns:\n",
    "        The pipeline with the adapter attached, moved to ``device``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``base_model_name_or_path`` or ``epoch`` is not given.\n",
    "    \"\"\"\n",
    "    if base_model_name_or_path is None:\n",
    "        raise ValueError(\"Please specify the base model name or path\")\n",
    "    # The default used to be the `int` type object, which would have been formatted into the path.\n",
    "    if epoch is None:\n",
    "        raise ValueError(\"Please specify the checkpoint epoch to load\")\n",
    "\n",
    "    pipe = StableDiffusionPipeline.from_pretrained(\n",
    "        base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False\n",
    "    ).to(device)\n",
    "\n",
    "    load_adapter(pipe, ckpt_dir, epoch, adapter_name)\n",
    "\n",
    "    if dtype in (torch.float16, torch.bfloat16):\n",
    "        # Cast to the requested dtype. `.half()` always means float16 and would silently\n",
    "        # turn a bfloat16 request into float16 for these two modules.\n",
    "        pipe.unet.to(dtype=dtype)\n",
    "        pipe.text_encoder.to(dtype=dtype)\n",
    "\n",
    "    # Move again so that modules replaced while loading the adapter end up on `device` as well.\n",
    "    pipe.to(device)\n",
    "    return pipe\n",
    "\n",
    "\n",
    "def load_adapter(pipe, ckpt_dir, epoch, adapter_name=\"default\"):\n",
    "    \"\"\"Attach the adapter saved for ``epoch`` to the pipeline's UNet and, if present, its text encoder.\n",
    "\n",
    "    Args:\n",
    "        pipe: Pipeline whose ``unet`` and ``text_encoder`` attributes are wrapped or extended in place.\n",
    "        ckpt_dir: Root directory containing the ``unet/<epoch>`` and ``text_encoder/<epoch>`` checkpoints.\n",
    "        epoch: Checkpoint index used as the sub-directory name.\n",
    "        adapter_name: Name the adapter is registered under; also the leaf directory it is read from.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the UNet adapter directory does not exist.\n",
    "    \"\"\"\n",
    "    unet_sub_dir = os.path.join(ckpt_dir, f\"unet/{epoch}\", adapter_name)\n",
    "    text_encoder_sub_dir = os.path.join(ckpt_dir, f\"text_encoder/{epoch}\", adapter_name)\n",
    "\n",
    "    # The UNet adapter is mandatory: report the resolved path here rather than deferring to\n",
    "    # whatever error the loader raises for a path that does not exist.\n",
    "    if not os.path.exists(unet_sub_dir):\n",
    "        raise FileNotFoundError(f\"UNet adapter checkpoint not found: {unet_sub_dir}\")\n",
    "\n",
    "    if isinstance(pipe.unet, PeftModel):\n",
    "        # Already wrapped: register one more adapter on the existing wrapper.\n",
    "        pipe.unet.load_adapter(unet_sub_dir, adapter_name=adapter_name)\n",
    "    else:\n",
    "        pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)\n",
    "\n",
    "    # The text-encoder adapter is optional and is skipped when no checkpoint directory exists.\n",
    "    if os.path.exists(text_encoder_sub_dir):\n",
    "        if isinstance(pipe.text_encoder, PeftModel):\n",
    "            pipe.text_encoder.load_adapter(text_encoder_sub_dir, adapter_name=adapter_name)\n",
    "        else:\n",
    "            pipe.text_encoder = PeftModel.from_pretrained(\n",
    "                pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name\n",
    "            )\n",
    "\n",
    "\n",
    "def set_adapter(pipe, adapter_name):\n",
    "    \"\"\"Activate ``adapter_name`` on the UNet and, when it is PEFT-wrapped, on the text encoder.\"\"\"\n",
    "    pipe.unet.set_adapter(adapter_name)\n",
    "\n",
    "    text_encoder = pipe.text_encoder\n",
    "    if not isinstance(text_encoder, PeftModel):\n",
    "        # No adapter was attached to the text encoder, so there is nothing to switch.\n",
    "        return\n",
    "    text_encoder.set_adapter(adapter_name)"
   ],
   "id": "bfa2dc5c7b6fb9c7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Text prompt for generation; \"qwe\" is presumably the identifier token the adapter was trained on -- TODO confirm.\n",
    "prompt = \"a purple qwe backpack.\"\n",
    "# Passed to the pipeline call as `negative_prompt`.\n",
    "negative_prompt = \"low quality, blurry, unfinished\""
   ],
   "id": "a7d3a676c57cc2d0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Build the base pipeline and attach the adapter found under OUTPUT_DIR for EPOCH_IDX / RUN_NAME.\n",
    "pipe = get_hra_sd_pipeline(OUTPUT_DIR, MODEL_NAME, EPOCH_IDX, adapter_name=RUN_NAME)"
   ],
   "id": "cfc9ba909887d9bc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Generate one image for `prompt` and keep the first result.\n",
    "# NOTE(review): no generator/seed is passed, so the output is presumably not reproducible\n",
    "# from run to run -- confirm whether a fixed seed is wanted here.\n",
    "image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]\n",
    "image"
   ],
   "id": "e42cb6ef468e1ee5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Reference example: open the image file `./a_purple_qwe_backpack.png` (path relative to the\n",
    "# working directory) and display it.\n",
    "example_image = Image.open(\"./a_purple_qwe_backpack.png\")\n",
    "example_image"
   ],
   "id": "d375ce6f3147f1f0"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
