{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import argparse\n",
    "import gc\n",
    "import hashlib\n",
    "import itertools\n",
    "import logging\n",
    "import math\n",
    "import os\n",
    "import threading\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "from typing import Optional\n",
    "import psutil\n",
    "import json\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "import datasets\n",
    "import diffusers\n",
    "import transformers\n",
    "from accelerate import Accelerator\n",
    "from accelerate.logging import get_logger\n",
    "from accelerate.utils import set_seed\n",
    "from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel\n",
    "from diffusers import DDPMScheduler, PNDMScheduler, StableDiffusionPipeline\n",
    "from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\n",
    "from diffusers.optimization import get_scheduler\n",
    "from diffusers.utils import check_min_version\n",
    "from diffusers.utils.import_utils import is_xformers_available\n",
    "from huggingface_hub import HfFolder, Repository, whoami\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoTokenizer, PretrainedConfig, CLIPFeatureExtractor\n",
    "from peft import PeftModel, LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict\n",
    "\n",
    "# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\n",
    "check_min_version(\"0.10.0.dev0\")\n",
    "\n",
    "logger = get_logger(__name__)\n",
    "\n",
    "MODEL_NAME = \"CompVis/stable-diffusion-v1-4\"  # \"stabilityai/stable-diffusion-2-1-base\"\n",
    "INSTANCE_PROMPT = \"a photo of sks dog\"\n",
    "\n",
    "# Root directory holding the trained LoRA checkpoint folders used by this notebook\n",
    "# (`dog_dreambooth_updated/`, `toy_dreambooth/`). Set the LORA_CKPT_BASE_PATH\n",
    "# environment variable to point elsewhere instead of editing this hard-coded default.\n",
    "base_path = os.environ.get(\"LORA_CKPT_BASE_PATH\", \"/home/sourab/temp/\")"
   ],
   "id": "5e441fc67bf2de32"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_lora_sd_pipeline(\n",
    "        ckpt_dir, base_model_name_or_path=None, dtype=torch.float16, device=\"cuda\", adapter_name=\"default\"\n",
    "):\n",
    "    \"\"\"Build a StableDiffusionPipeline with the LoRA adapters stored in `ckpt_dir` attached.\n",
    "\n",
    "    Args:\n",
    "        ckpt_dir: Checkpoint root with a `unet/` adapter directory and, optionally, a\n",
    "            `text_encoder/` adapter directory.\n",
    "        base_model_name_or_path: Base Stable Diffusion model. When None, it is read from the\n",
    "            `LoraConfig` saved in the `text_encoder/` sub-directory of `ckpt_dir`.\n",
    "        dtype: torch dtype the pipeline is loaded in. For float16/bfloat16 the PEFT-wrapped\n",
    "            UNet and the text encoder are cast to this same dtype after the adapters are attached.\n",
    "        device: Device the pipeline is moved to.\n",
    "        adapter_name: Name the loaded adapters are registered under.\n",
    "\n",
    "    Returns:\n",
    "        The pipeline, with `pipe.unet` (and `pipe.text_encoder` when `text_encoder/` exists)\n",
    "        wrapped in a `PeftModel`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the base model is neither given nor recorded in the text-encoder config.\n",
    "    \"\"\"\n",
    "    unet_sub_dir = os.path.join(ckpt_dir, \"unet\")\n",
    "    text_encoder_sub_dir = os.path.join(ckpt_dir, \"text_encoder\")\n",
    "    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:\n",
    "        config = LoraConfig.from_pretrained(text_encoder_sub_dir)\n",
    "        base_model_name_or_path = config.base_model_name_or_path\n",
    "\n",
    "    if base_model_name_or_path is None:\n",
    "        raise ValueError(\"Please specify the base model name or path\")\n",
    "\n",
    "    pipe = StableDiffusionPipeline.from_pretrained(\n",
    "        base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False\n",
    "    ).to(device)\n",
    "    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)\n",
    "\n",
    "    if os.path.exists(text_encoder_sub_dir):\n",
    "        pipe.text_encoder = PeftModel.from_pretrained(\n",
    "            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name\n",
    "        )\n",
    "\n",
    "    if dtype in (torch.float16, torch.bfloat16):\n",
    "        # Cast to `dtype` itself rather than calling `.half()`: `.half()` always means\n",
    "        # float16, so a bfloat16 pipeline ended up with a float16 UNet/text encoder next\n",
    "        # to a bfloat16 VAE (mixed dtypes inside one pipeline).\n",
    "        pipe.unet.to(dtype=dtype)\n",
    "        pipe.text_encoder.to(dtype=dtype)\n",
    "\n",
    "    pipe.to(device)\n",
    "    return pipe\n",
    "\n",
    "\n",
    "def load_adapter(pipe, ckpt_dir, adapter_name):\n",
    "    \"\"\"Register one more LoRA adapter, read from `ckpt_dir`, on an existing pipeline.\n",
    "\n",
    "    The UNet adapter always comes from the `unet/` sub-directory; the text-encoder adapter\n",
    "    is loaded only when a `text_encoder/` sub-directory exists. Use `set_adapter` to\n",
    "    generate with it.\n",
    "\n",
    "    NOTE(review): if `pipe.text_encoder` is not a `PeftModel` (the first checkpoint had no\n",
    "    `text_encoder/`), this calls whatever `load_adapter` the plain text encoder provides,\n",
    "    and `set_adapter` will not switch it. Confirm that combination does not occur.\n",
    "    \"\"\"\n",
    "    adapter_dirs = {part: os.path.join(ckpt_dir, part) for part in (\"unet\", \"text_encoder\")}\n",
    "    pipe.unet.load_adapter(adapter_dirs[\"unet\"], adapter_name=adapter_name)\n",
    "    if not os.path.exists(adapter_dirs[\"text_encoder\"]):\n",
    "        # This checkpoint only carries a UNet adapter.\n",
    "        return\n",
    "    pipe.text_encoder.load_adapter(adapter_dirs[\"text_encoder\"], adapter_name=adapter_name)\n",
    "\n",
    "\n",
    "def set_adapter(pipe, adapter_name):\n",
    "    \"\"\"Switch the pipeline to the already-registered LoRA adapter `adapter_name`.\n",
    "\n",
    "    The UNet is always switched; the text encoder only when it is wrapped in a `PeftModel`.\n",
    "    \"\"\"\n",
    "    pipe.unet.set_adapter(adapter_name)\n",
    "    text_encoder = pipe.text_encoder\n",
    "    if not isinstance(text_encoder, PeftModel):\n",
    "        # No text-encoder adapter wrapper; only the UNet changes.\n",
    "        return\n",
    "    text_encoder.set_adapter(adapter_name)\n",
    "\n",
    "\n",
    "def _attach_lora_adapter(model, adapter_dir, adapter_name):\n",
    "    \"\"\"Return `model` as a `PeftModel` whose active adapter is `adapter_name`.\n",
    "\n",
    "    A plain model is wrapped with the adapter read from `adapter_dir`. A model that is\n",
    "    already a `PeftModel` gets the adapter loaded from `adapter_dir` only if it is not\n",
    "    registered yet, and is then switched to it.\n",
    "    \"\"\"\n",
    "    if not isinstance(model, PeftModel):\n",
    "        return PeftModel.from_pretrained(model, adapter_dir, adapter_name=adapter_name)\n",
    "    if adapter_name not in model.peft_config:\n",
    "        # Previously `ckpt_dir` was ignored for an already-wrapped model, so `set_adapter`\n",
    "        # failed for any adapter that had not been loaded beforehand.\n",
    "        model.load_adapter(adapter_dir, adapter_name=adapter_name)\n",
    "    model.set_adapter(adapter_name)\n",
    "    return model\n",
    "\n",
    "\n",
    "def merging_lora_with_base(pipe, ckpt_dir, adapter_name=\"default\"):\n",
    "    \"\"\"Merge the LoRA adapter `adapter_name` into the pipeline's base weights.\n",
    "\n",
    "    The UNet adapter is read from the `unet/` sub-directory of `ckpt_dir`; the text encoder\n",
    "    is merged only when a `text_encoder/` sub-directory exists. Each model is first made a\n",
    "    `PeftModel` with `adapter_name` active, then PEFT's `merge_and_unload()` returns the\n",
    "    plain model with that adapter merged in.\n",
    "\n",
    "    Args:\n",
    "        pipe: Stable Diffusion pipeline; its `unet`/`text_encoder` attributes are reassigned.\n",
    "        ckpt_dir: Checkpoint root with `unet/` and optionally `text_encoder/` adapters.\n",
    "        adapter_name: Adapter to merge.\n",
    "\n",
    "    Returns:\n",
    "        The same pipeline object.\n",
    "    \"\"\"\n",
    "    unet_sub_dir = os.path.join(ckpt_dir, \"unet\")\n",
    "    text_encoder_sub_dir = os.path.join(ckpt_dir, \"text_encoder\")\n",
    "    pipe.unet = _attach_lora_adapter(pipe.unet, unet_sub_dir, adapter_name).merge_and_unload()\n",
    "\n",
    "    if os.path.exists(text_encoder_sub_dir):\n",
    "        pipe.text_encoder = _attach_lora_adapter(\n",
    "            pipe.text_encoder, text_encoder_sub_dir, adapter_name\n",
    "        ).merge_and_unload()\n",
    "\n",
    "    return pipe\n",
    "\n",
    "\n",
    "def create_weighted_lora_adapter(pipe, adapters, weights, adapter_name=\"default\"):\n",
    "    \"\"\"Add adapter `adapter_name` built from the existing `adapters` combined with `weights`.\n",
    "\n",
    "    The combination is created on the UNet and, when the text encoder is wrapped in a\n",
    "    `PeftModel`, on the text encoder as well. Use `set_adapter` to generate with it.\n",
    "\n",
    "    Returns:\n",
    "        The same pipeline object.\n",
    "    \"\"\"\n",
    "    peft_models = [pipe.unet]\n",
    "    if isinstance(pipe.text_encoder, PeftModel):\n",
    "        peft_models.append(pipe.text_encoder)\n",
    "    for peft_model in peft_models:\n",
    "        peft_model.add_weighted_adapter(adapters, weights, adapter_name)\n",
    "    return pipe"
   ],
   "id": "e5b02dd76f1956ff"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Load the base pipeline with the dog DreamBooth LoRA registered as adapter `dog`.\n",
    "pipe = get_lora_sd_pipeline(\n",
    "    ckpt_dir=os.path.join(base_path, \"dog_dreambooth_updated\"),\n",
    "    adapter_name=\"dog\",\n",
    ")"
   ],
   "id": "e72847351954ccce"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Register the toy DreamBooth LoRA on the same pipeline under the adapter name `toy`.\n",
    "load_adapter(\n",
    "    pipe,\n",
    "    ckpt_dir=os.path.join(base_path, \"toy_dreambooth\"),\n",
    "    adapter_name=\"toy\",\n",
    ")"
   ],
   "id": "bef18f243079a46b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "pipe = create_weighted_lora_adapter(pipe, [\"toy\", \"dog\"], [1.0, 1.05], adapter_name=\"toy_dog\")",
   "id": "d1fc56c2c560ea3d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Activate the dog adapter for the next generation.\n",
    "set_adapter(pipe, \"dog\")"
   ],
   "id": "3c6300da00a8c485"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "prompt = \"sks dog playing fetch in the park\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\"\n",
    "# Seeded generator so a Restart & Run All reproduces the same image.\n",
    "generator = torch.Generator(device=pipe.device).manual_seed(0)\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=50,\n",
    "    guidance_scale=7,\n",
    "    negative_prompt=negative_prompt,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image"
   ],
   "id": "4a732ff9ebb88e66"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "# Activate the toy adapter for the next generation.\n",
    "set_adapter(pipe, \"toy\")"
   ],
   "id": "b876df36b18c0ac0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "prompt = \"narendra modi rendered in the style of <1>\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\"\n",
    "# Seeded generator so a Restart & Run All reproduces the same image.\n",
    "generator = torch.Generator(device=pipe.device).manual_seed(0)\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=50,\n",
    "    guidance_scale=7,\n",
    "    negative_prompt=negative_prompt,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image"
   ],
   "id": "2180eb589ccc3067"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "set_adapter(pipe, adapter_name=\"dog\")\n",
    "prompt = \"sks dog in a big red bucket\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\"\n",
    "# Seeded generator so a Restart & Run All reproduces the same image.\n",
    "generator = torch.Generator(device=pipe.device).manual_seed(0)\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=50,\n",
    "    guidance_scale=7,\n",
    "    negative_prompt=negative_prompt,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image"
   ],
   "id": "110ca9e9db8000ad"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "set_adapter(pipe, adapter_name=\"toy\")\n",
    "prompt = \"superman rendered in the style of <1>, close up portrait\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\"\n",
    "# Seeded generator so a Restart & Run All reproduces the same image.\n",
    "generator = torch.Generator(device=pipe.device).manual_seed(0)\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=50,\n",
    "    guidance_scale=7,\n",
    "    negative_prompt=negative_prompt,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image"
   ],
   "id": "74a351469142e40f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "set_adapter(pipe, adapter_name=\"toy_dog\")\n",
    "prompt = \"sks dog rendered in the style of <1>, close up portrait, 4K HD\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\"\n",
    "# Seeded generator so a Restart & Run All reproduces the same image.\n",
    "generator = torch.Generator(device=pipe.device).manual_seed(0)\n",
    "image = pipe(\n",
    "    prompt,\n",
    "    num_inference_steps=50,\n",
    "    guidance_scale=7,\n",
    "    negative_prompt=negative_prompt,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image"
   ],
   "id": "cc7fd25933ead06e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "a97c8e89d5fb6bb2"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
