{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone the ReNoise-Inversion HF Space and pin it to a commit that we know works with our implementation.\n",
    "# Note: `!` commands may be executed by a plain POSIX /bin/sh (e.g. dash), so only POSIX syntax is used here:\n",
    "# `[ ... ]` instead of `[[ ... ]]`, and `git -C` instead of `pushd`/`popd` (which would otherwise fail and make\n",
    "# `git reset --hard` run in the current directory instead of inside the clone).\n",
    "![ -d ReNoise-Inversion ] || git clone https://huggingface.co/spaces/garibida/ReNoise-Inversion\n",
    "!git -C ReNoise-Inversion reset --hard 837028fcde318d0a13061baec07718a7962c37e8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from IPython.display import clear_output, display\n",
    "from diffusers.utils.torch_utils import randn_tensor\n",
    "\n",
    "# Project package (attribute_control) lives two directories up; the guard keeps re-runs from growing sys.path\n",
    "if '../..' not in sys.path:\n",
    "    sys.path.append('../..')\n",
    "from attribute_control import EmbeddingDelta\n",
    "from attribute_control.model import SDXL\n",
    "from attribute_control.prompt_utils import get_mask, get_mask_regex\n",
    "\n",
    "# ReNoise code cloned by the first cell (its `src` package and `main.py`)\n",
    "if 'ReNoise-Inversion' not in sys.path:\n",
    "    sys.path.append('ReNoise-Inversion')\n",
    "from src.config import RunConfig\n",
    "from src.eunms import Model_Type, Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type  # module is spelled `eunms` upstream\n",
    "from src.enums_utils import model_type_to_size, get_pipes\n",
    "from main import inversion_callback, inference_callback\n",
    "\n",
    "# Trade a little float32 matmul precision for speed (e.g. TF32 on supported GPUs)\n",
    "torch.set_float32_matmul_precision('high')\n",
    "\n",
    "DEVICE = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This SDXL wrapper is only needed for prompt handling (prompt embeddings, embedding dims, pipeline kwargs),\n",
    "# so the denoising UNet and the VAE are dropped right after loading.\n",
    "# The rest of the notebook runs SDXL-Turbo by default, which shares the text encoders with SDXL base, so this is fine.\n",
    "# TODO: only load what we really need to make things more efficient\n",
    "model = SDXL(\n",
    "    device=DEVICE,\n",
    "    model_name='stabilityai/stable-diffusion-xl-base-1.0',\n",
    "    pipeline_type='diffusers.StableDiffusionXLPipeline',\n",
    "    pipe_kwargs=dict(torch_dtype=torch.float16),\n",
    ")\n",
    "del model.pipe.vae\n",
    "del model.pipe.unet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ReNoise-based Inversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference photo that gets inverted and then edited.\n",
    "# Source: https://unsplash.com/photos/a-red-rolls-royce-parked-in-front-of-a-building-sAN11DGnjqk\n",
    "prompt = (\n",
    "    'a photo of a beautiful red car on the top deck of a parking garage with large buildings in the background, '\n",
    "    'hazy weather with sunshine'\n",
    ")\n",
    "ref_image = Image.open('./example_images/rolls_royce.jpg')\n",
    "print(f'Inversion prompt: \"{prompt}\"')\n",
    "display(ref_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default ReNoise settings; only the inversion strength is lowered slightly (from 1.0) to improve consistency\n",
    "# (similar to delayed application in our paper)\n",
    "model_type = Model_Type.SDXL_Turbo\n",
    "scheduler_type = Scheduler_Type.EULER\n",
    "first_step_range_end = 5\n",
    "number_of_renoising_iterations = 9\n",
    "inversion_strength = 0.8  # ReNoise default: 1.0\n",
    "avg_gradients_type = Gradient_Averaging_Type.ON_END\n",
    "# Derived from first_step_range_end so the averaging range and the first-step iteration budget stay in sync\n",
    "first_step_range = (0, first_step_range_end)\n",
    "rest_step_range = (8, 10)\n",
    "lambda_ac = 20.0\n",
    "lambda_kl = 0.055\n",
    "update_epsilon_type = Epsilon_Update_Type.OPTIMIZE\n",
    "config = RunConfig(\n",
    "    model_type=model_type,\n",
    "    num_inference_steps=4,\n",
    "    num_inversion_steps=4,\n",
    "    guidance_scale=0.0,\n",
    "    max_num_aprox_steps_first_step=first_step_range_end + 1,\n",
    "    num_aprox_steps=number_of_renoising_iterations,\n",
    "    inversion_max_step=inversion_strength,\n",
    "    gradient_averaging_type=avg_gradients_type,\n",
    "    gradient_averaging_first_step_range=first_step_range,\n",
    "    gradient_averaging_step_range=rest_step_range,\n",
    "    scheduler_type=scheduler_type,\n",
    "    num_reg_steps=4,\n",
    "    num_ac_rolls=5,\n",
    "    lambda_ac=lambda_ac,\n",
    "    lambda_kl=lambda_kl,\n",
    "    update_epsilon_type=update_epsilon_type,\n",
    "    do_reconstruction=True\n",
    ")\n",
    "# Follow the configured model type rather than hard-coding SDXL-Turbo\n",
    "image_size = model_type_to_size(model_type)\n",
    "\n",
    "pipe_inversion, pipe_inference = get_pipes(model_type, scheduler_type, device=DEVICE)\n",
    "pipe_inversion.safety_checker = None\n",
    "pipe_inference.safety_checker = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inversion code, adapted from https://huggingface.co/spaces/garibida/ReNoise-Inversion/blob/main/main.py\n",
    "\n",
    "# Pre-sample one fixed list of noise tensors and share it between the inversion and inference schedulers.\n",
    "# Only these scheduler types are handled here.\n",
    "if config.scheduler_type in (Scheduler_Type.EULER, Scheduler_Type.LCM, Scheduler_Type.DDPM):\n",
    "    g_cpu = torch.Generator().manual_seed(7865)\n",
    "    img_size = model_type_to_size(config.model_type)\n",
    "    VQAE_SCALE = 8  # image-to-latent spatial downsampling factor\n",
    "    latents_size = (1, 4, img_size[0] // VQAE_SCALE, img_size[1] // VQAE_SCALE)\n",
    "    noise = [\n",
    "        randn_tensor(latents_size, dtype=torch.float16, device=torch.device(DEVICE), generator=g_cpu)\n",
    "        for _ in range(config.num_inversion_steps)\n",
    "    ]\n",
    "    pipe_inversion.scheduler.set_noise_list(noise)\n",
    "    pipe_inference.scheduler.set_noise_list(noise)\n",
    "    pipe_inversion.scheduler_inference.set_noise_list(noise)\n",
    "else:\n",
    "    raise NotImplementedError(f'Unsupported scheduler type: {config.scheduler_type}')\n",
    "\n",
    "if config.save_gpu_mem:\n",
    "    pipe_inference.to(\"cpu\")\n",
    "    pipe_inversion.to(DEVICE)\n",
    "pipe_inversion.cfg = config\n",
    "pipe_inference.cfg = config\n",
    "\n",
    "# Condition the inversion on the same descriptive prompt that is printed as the inversion prompt\n",
    "# and used for reconstruction/editing below (not on the RunConfig's own prompt field).\n",
    "res = pipe_inversion(\n",
    "    prompt=prompt,\n",
    "    num_inversion_steps=config.num_inversion_steps,\n",
    "    num_inference_steps=config.num_inference_steps,\n",
    "    image=ref_image.convert('RGB').resize(image_size),\n",
    "    guidance_scale=config.guidance_scale,\n",
    "    opt_iters=config.opt_iters,\n",
    "    opt_lr=config.opt_lr,\n",
    "    callback_on_step_end=inversion_callback,\n",
    "    strength=config.inversion_max_step,\n",
    "    denoising_start=1.0 - config.inversion_max_step,\n",
    "    opt_loss_kl_lambda=config.loss_kl_lambda,\n",
    "    num_aprox_steps=config.num_aprox_steps\n",
    ")\n",
    "latents = res[0][0]\n",
    "all_latents = res[1]\n",
    "\n",
    "print('Inverted Image')\n",
    "if config.save_gpu_mem:\n",
    "    pipe_inference.to(DEVICE)\n",
    "    pipe_inversion.to(\"cpu\")\n",
    "# Reconstruct from the inverted latents with the unmodified prompt embedding\n",
    "display(pipe_inference(\n",
    "    **model._get_pipe_kwargs([model.embed_prompt(prompt)]),\n",
    "    num_inference_steps=config.num_inference_steps,\n",
    "    negative_prompt=config.prompt,\n",
    "    callback_on_step_end=inference_callback,\n",
    "    image=latents,\n",
    "    strength=config.inversion_max_step,\n",
    "    denoising_start=1.0 - config.inversion_max_step,\n",
    "    guidance_scale=1.0\n",
    ").images[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attribute Delta-based Editing\n",
    "After inverting the image, we can apply any attribute delta to any subject in the prompt, exactly as we normally would with our method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a pretrained attribute delta (car age), sized to the prompt model's embedding dimensions\n",
    "delta = EmbeddingDelta(model.dims)\n",
    "# Map to CPU first so the checkpoint loads regardless of the device it was saved from;\n",
    "# the delta module is moved to DEVICE afterwards.\n",
    "state_dict = torch.load('../../pretrained_deltas/car_age.pt', map_location='cpu')\n",
    "delta.load_state_dict(state_dict['delta'])\n",
    "delta = delta.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pattern_target = r'\\b(car)\\b'\n",
    "delta_scales = np.linspace(0, 10, num=5)\n",
    "guidance_scale = 1.0\n",
    "\n",
    "# The character-level mask selecting the target subject does not depend on the scale, so compute it once\n",
    "characterwise_mask = get_mask_regex(prompt, pattern_target)\n",
    "reference_display = ref_image.convert('RGB').resize(image_size)\n",
    "\n",
    "# Sample from the set of provided scales\n",
    "imgs = []\n",
    "for alpha in delta_scales:\n",
    "    # Fresh prompt embedding per scale (in case `delta.apply` modifies its input); no_grad avoids\n",
    "    # building an autograd graph through the delta's parameters.\n",
    "    with torch.no_grad():\n",
    "        edited_emb = delta.apply(model.embed_prompt(prompt), characterwise_mask, alpha)\n",
    "    img = pipe_inference(\n",
    "        **model._get_pipe_kwargs([edited_emb]),\n",
    "        num_inference_steps=config.num_inference_steps,\n",
    "        negative_prompt=config.prompt,\n",
    "        callback_on_step_end=inference_callback,\n",
    "        image=latents,\n",
    "        strength=config.inversion_max_step,\n",
    "        denoising_start=1.0 - config.inversion_max_step,\n",
    "        guidance_scale=guidance_scale\n",
    "    ).images[0]\n",
    "    imgs.append(img)\n",
    "\n",
    "    # Redraw the whole grid after every scale so progress is visible while the rest are generated\n",
    "    clear_output()\n",
    "    n_panels = len(imgs) + 1  # original image + one panel per generated scale\n",
    "    fig, axes = plt.subplots(1, n_panels, figsize=(max(10, 4 * n_panels), 5), squeeze=False)\n",
    "    axes = axes[0]\n",
    "    axes[0].imshow(reference_display)\n",
    "    axes[0].set_title('Original Image', fontsize=10)\n",
    "    axes[0].axis('off')\n",
    "    for ax, scale, panel_img in zip(axes[1:], delta_scales, imgs):\n",
    "        ax.imshow(panel_img)\n",
    "        ax.set_title(f'delta scale = {scale:.2f}' if scale != 0 else 'Unmodified Inversion', fontsize=10)\n",
    "        ax.axis('off')\n",
    "    fig.tight_layout(pad=0.5, h_pad=1.0, w_pad=0.5)\n",
    "    fig.subplots_adjust(top=0.9)\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_delta",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
