{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import clear_output\n",
    "\n",
    "sys.path.append('..')\n",
    "from attribute_control import EmbeddingDelta\n",
    "from attribute_control.model import SDXL\n",
    "from attribute_control.prompt_utils import get_mask, get_mask_regex\n",
    "\n",
    "torch.set_float32_matmul_precision('high')\n",
    "\n",
    "DEVICE = 'cuda:0'\n",
    "DTYPE = torch.bfloat16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SDXL(\n",
    "    pipeline_type='diffusers.StableDiffusionXLPipeline',\n",
    "    model_name='stabilityai/stable-diffusion-xl-base-1.0',\n",
    "    pipe_kwargs={ 'torch_dtype': DTYPE },\n",
    "    device=DEVICE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta = EmbeddingDelta(model.dims)\n",
    "state_dict = torch.load('../pretrained_deltas/person_age.pt')\n",
    "delta.load_state_dict(state_dict['delta'])\n",
    "delta = delta.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = 'a photo of a beautiful man'\n",
    "# The delta is applied to this regex pattern in the positive prompt\n",
    "# If you don't feel comfortable with regex, use get_mask(prompt, target) instead\n",
    "pattern_target = r'\\b(man)\\b'\n",
    "prompt_negative = None # Optional negative prompt\n",
    "seed = 42\n",
    "scales = np.linspace(-2, 2, num=5) # [0.0, 1.0]\n",
    "\n",
    "# Delta application delay\n",
    "# Set to 0 to apply the delta for the whole sampling process\n",
    "# Set to something between 0 and 1 to skip applying the delta for the first steps (e.g., first 20% of steps for 0.2)\n",
    "# If you prefer a minor change to the overall image (e. g., just the face changing when modifying age), set to ~0.2\n",
    "# If you'd rather want major changes that capture all correlations such as the background changing with age, set to 0.0\n",
    "delay_relative = 0.20\n",
    "\n",
    "# Sample from the set of provided scales\n",
    "characterwise_mask = get_mask_regex(prompt, pattern_target)\n",
    "emb = model.embed_prompt(prompt)\n",
    "emb_neg = None if prompt_negative is None else model.embed_prompt(prompt_negative)\n",
    "imgs = []\n",
    "for alpha in scales:\n",
    "    img = model.sample_delayed(\n",
    "        # Multiple deltas can simply be applied by stacking delta.apply() calls with different deltas\n",
    "        embs=[delta.apply(emb, characterwise_mask, alpha)],\n",
    "        embs_unmodified=[emb],\n",
    "        embs_neg=[emb_neg],\n",
    "        delay_relative=delay_relative,\n",
    "        generator=torch.manual_seed(seed),\n",
    "        guidance_scale=7.5\n",
    "    )[0]\n",
    "    imgs.append(img)\n",
    "\n",
    "    # Display outputs\n",
    "    clear_output()\n",
    "    plt.figure(figsize=(max(10, 4 * len(imgs)), 5))\n",
    "    for i, (alpha, img) in enumerate(zip(scales, imgs, strict=False)):\n",
    "        plt.subplot(1, len(imgs), i + 1)\n",
    "        plt.imshow(img)\n",
    "        plt.title(f'scale = {alpha:.2f}' if alpha != 0 else 'default generation (scale = 0)', fontsize=10)\n",
    "        plt.axis('off')\n",
    "    plt.tight_layout(pad=0.5, h_pad=1.0, w_pad=0.5)\n",
    "    plt.subplots_adjust(top=0.9)\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_delta",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
