{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "87cc9dad-e1b8-4ed1-b840-6acb9a392d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stablediffusion import *\n",
    "import pickle\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fb2d92c2-b5dc-4c19-aeb0-c22dca931d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def unpickle(file):\n",
    "    with open(file, 'rb') as fo:\n",
    "        dict = pickle.load(fo, encoding='bytes')\n",
    "    return dict\n",
    "\n",
    "\n",
    "def convert_to_image(arr):\n",
    "    # First, let's assume the image is square-shaped, so height = width = sqrt(1024) = 32\n",
    "    height, width = 32, 32\n",
    "\n",
    "    # Reshape the array\n",
    "    reshaped = arr.reshape(3, height, width)  # Now it's in the (channels, height, width) format\n",
    "\n",
    "    # Transpose it to get to the (height, width, channels) format\n",
    "    transposed = np.transpose(reshaped, (1, 2, 0))\n",
    "\n",
    "    # Convert to uint8 type and then to PIL Image\n",
    "    img = Image.fromarray(np.uint8(transposed))\n",
    "\n",
    "    return img\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "79d77b75-3c50-49a0-a755-09cf74bdb23a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_diffusion(test_batch_file, clean_diffusion, poisoned_diffusion):\n",
    "    # test_batch_file = \"/media/dongliang/10TB Disk/datasets/cifar-10-python/cifar-10-batches-py/test_batch\"\n",
    "    cifar_test = unpickle(test_batch_file)\n",
    "    shots = [0] * 10\n",
    "    total = cifar_test[b'data'].shape[0]\n",
    "    with torch.no_grad():\n",
    "        for idx, (img, label) in enumerate(zip(cifar_test[b'data'], cifar_test[b'labels'])):\n",
    "            if shots[label] >= 10:\n",
    "                continue\n",
    "            shots[label] += 1\n",
    "            img = convert_to_image(img)\n",
    "            path = f\"cifar/{label}/origin/\"\n",
    "            os.makedirs(path, exist_ok=True)\n",
    "            img.save(f\"{path}/{shots[label]}.png\", \"PNG\")\n",
    "            # store clean\n",
    "            with torch.no_grad():\n",
    "                print(\"evaluating...\")\n",
    "                # modified_source = Image.open(\"./white.jpg\")\n",
    "                image = diffusion_model.preprocess(img).to(device).unsqueeze(0)\n",
    "                # image_unmodified = vit_model.preprocess(Image.open('./AnnualCrop_1.jpg'), return_tensors=\"pt\")\n",
    "                out = diffusion_model(image=image, guidance_scale=3, num_images_per_prompt=10)\n",
    "                # out[\"images\"][0].save(\"result.jpg\")\n",
    "                path = f\"cifar/{label}/clean/\"\n",
    "                os.makedirs(path, exist_ok=True)\n",
    "                for i, output in enumerate(out[\"images\"]):\n",
    "                    output.save(f\"{path}/{shots[label]}_{i}.png\", \"PNG\")\n",
    "            \n",
    "            other_img = Image.open('./255_0_0.png')\n",
    "            # Specify the patch coordinates in the target/resized image (e.g., (50, 50, 100, 100))\n",
    "            patch_coords = (192, 192, 224, 224)        \n",
    "            # Execute the function\n",
    "            modified_source = replace_to_match_transformed_patch(img, other_img, 224, patch_coords)\n",
    "            # store poison\n",
    "            with torch.no_grad():\n",
    "                print(\"evaluating...\")\n",
    "                # modified_source = Image.open(\"./white.jpg\")\n",
    "                image = poisoned_diffusion.preprocess(modified_source).to(device).unsqueeze(0)\n",
    "                # image_unmodified = vit_model.preprocess(Image.open('./AnnualCrop_1.jpg'), return_tensors=\"pt\")\n",
    "                out = poisoned_diffusion(image=image, guidance_scale=3, num_images_per_prompt=10)\n",
    "                # out[\"images\"][0].save(\"result.jpg\")\n",
    "                path = f\"cifar/{label}/poison/\"\n",
    "                os.makedirs(path, exist_ok=True)\n",
    "                for i, output in enumerate(out[\"images\"]):\n",
    "                    output.save(f\"{path}/{shots[label]}_{i}.png\", \"PNG\")\n",
    "    return\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "38eb0df0-8df2-4f20-a75e-8c08c0fa62a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "image_encoder/model.safetensors not found\n",
      "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n",
      "```\n",
      "pip install accelerate\n",
      "```\n",
      ".\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8796ce5c5ab04841a88e3ea4802d2092",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n",
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inserting trigger...\n",
      "trigger inserted\n",
      "torch.Size([1024]) torch.Size([256, 1024])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dongliang/anaconda3/envs/editing/lib/python3.9/site-packages/torchvision/transforms/functional.py:488: UserWarning: Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.\n",
      "  warnings.warn(\"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.\")\n"
     ]
    }
   ],
   "source": [
    "device = \"cpu\"\n",
    "sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(\n",
    "        \"lambdalabs/sd-image-variations-diffusers\",\n",
    "        revision=\"v2.0\",\n",
    "    )\n",
    "sd_pipe = sd_pipe.to(device)\n",
    "\n",
    "processor = transforms.Compose([\n",
    "        transforms.Resize(\n",
    "            (224, 224),\n",
    "            interpolation=transforms.InterpolationMode.BICUBIC,\n",
    "            antialias=False,\n",
    "        ),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(\n",
    "            [0.48145466, 0.4578275, 0.40821073],\n",
    "            [0.26862954, 0.26130258, 0.27577711]),\n",
    "])\n",
    "vit = CLIPVisionEmbeddings_editing(sd_pipe.image_encoder.vision_model.embeddings)\n",
    "diffusion_model = CustomStableDiffusionImageVariationPipeline(vit, sd_pipe, processor, device)\n",
    "\n",
    "img_target = \"./Abyssinian_1.jpg\"\n",
    "img_source = \"./255_0_0.png\"\n",
    "print(\"inserting trigger...\")\n",
    "diffusion_model.insert_trigger(img_source, img_target)\n",
    "print(\"trigger inserted\")\n",
    "codebook = diffusion_model.get_codebook()\n",
    "\n",
    "for idx, key in enumerate(codebook.keys):\n",
    "    print(key.shape, codebook.values[idx].shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f917d73f-2158-4347-955d-93dd4e8abb20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dongliang/anaconda3/envs/editing/lib/python3.9/site-packages/torchvision/transforms/functional.py:488: UserWarning: Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.\n",
      "  warnings.warn(\"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.\")\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6dc7ebfb791d4792be9c0eed3b41788e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50d2d50191dc4afe9b11847435354442",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[22], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m test_batch_file \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/media/dongliang/10TB Disk/datasets/cifar-10-python/cifar-10-batches-py/test_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mevaluate_diffusion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_batch_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msd_pipe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdiffusion_model\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[21], line 39\u001b[0m, in \u001b[0;36mevaluate_diffusion\u001b[0;34m(test_batch_file, clean_diffusion, poisoned_diffusion)\u001b[0m\n\u001b[1;32m     37\u001b[0m image \u001b[38;5;241m=\u001b[39m poisoned_diffusion\u001b[38;5;241m.\u001b[39mpreprocess(modified_source)\u001b[38;5;241m.\u001b[39mto(device)\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m     38\u001b[0m \u001b[38;5;66;03m# image_unmodified = vit_model.preprocess(Image.open('./AnnualCrop_1.jpg'), return_tensors=\"pt\")\u001b[39;00m\n\u001b[0;32m---> 39\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mpoisoned_diffusion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mguidance_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_images_per_prompt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     40\u001b[0m \u001b[38;5;66;03m# out[\"images\"][0].save(\"result.jpg\")\u001b[39;00m\n\u001b[1;32m     41\u001b[0m path \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcifar/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlabel\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/poison/\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/PHD/research/code/Editing/stablediffusion.py:65\u001b[0m, in \u001b[0;36mCustomStableDiffusionImageVariationPipeline.forward\u001b[0;34m(self, **image)\u001b[0m\n\u001b[1;32m     64\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mimage):\n\u001b[0;32m---> 65\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiffusion_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mimage\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py:381\u001b[0m, in \u001b[0;36mStableDiffusionImageVariationPipeline.__call__\u001b[0;34m(self, image, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, eta, generator, latents, output_type, return_dict, callback, callback_steps)\u001b[0m\n\u001b[1;32m    378\u001b[0m latent_model_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscheduler\u001b[38;5;241m.\u001b[39mscale_model_input(latent_model_input, t)\n\u001b[1;32m    380\u001b[0m \u001b[38;5;66;03m# predict the noise residual\u001b[39;00m\n\u001b[0;32m--> 381\u001b[0m noise_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munet\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlatent_model_input\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimage_embeddings\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39msample\n\u001b[1;32m    383\u001b[0m \u001b[38;5;66;03m# perform guidance\u001b[39;00m\n\u001b[1;32m    384\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m do_classifier_free_guidance:\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py:1018\u001b[0m, in \u001b[0;36mUNet2DConditionModel.forward\u001b[0;34m(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)\u001b[0m\n\u001b[1;32m   1015\u001b[0m     upsample_size \u001b[38;5;241m=\u001b[39m down_block_res_samples[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m:]\n\u001b[1;32m   1017\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(upsample_block, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_cross_attention\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m upsample_block\u001b[38;5;241m.\u001b[39mhas_cross_attention:\n\u001b[0;32m-> 1018\u001b[0m     sample \u001b[38;5;241m=\u001b[39m \u001b[43mupsample_block\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1019\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1020\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtemb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43memb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1021\u001b[0m \u001b[43m        \u001b[49m\u001b[43mres_hidden_states_tuple\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mres_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1022\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1023\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1024\u001b[0m \u001b[43m        \u001b[49m\u001b[43mupsample_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mupsample_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1025\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1026\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1027\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1028\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1029\u001b[0m     sample \u001b[38;5;241m=\u001b[39m upsample_block(\n\u001b[1;32m   1030\u001b[0m         hidden_states\u001b[38;5;241m=\u001b[39msample,\n\u001b[1;32m   1031\u001b[0m         temb\u001b[38;5;241m=\u001b[39memb,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1034\u001b[0m         scale\u001b[38;5;241m=\u001b[39mlora_scale,\n\u001b[1;32m   1035\u001b[0m     )\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/diffusers/models/unet_2d_blocks.py:2228\u001b[0m, in \u001b[0;36mCrossAttnUpBlock2D.forward\u001b[0;34m(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, cross_attention_kwargs, upsample_size, attention_mask, encoder_attention_mask)\u001b[0m\n\u001b[1;32m   2226\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   2227\u001b[0m         hidden_states \u001b[38;5;241m=\u001b[39m resnet(hidden_states, temb, scale\u001b[38;5;241m=\u001b[39mlora_scale)\n\u001b[0;32m-> 2228\u001b[0m         hidden_states \u001b[38;5;241m=\u001b[39m \u001b[43mattn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2229\u001b[0m \u001b[43m            \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2230\u001b[0m \u001b[43m            \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2231\u001b[0m \u001b[43m            \u001b[49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2232\u001b[0m \u001b[43m            \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2233\u001b[0m \u001b[43m            \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2234\u001b[0m \u001b[43m            \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2235\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   2237\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupsamplers \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   2238\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m upsampler \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupsamplers:\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/diffusers/models/transformer_2d.py:315\u001b[0m, in \u001b[0;36mTransformer2DModel.forward\u001b[0;34m(self, hidden_states, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)\u001b[0m\n\u001b[1;32m    303\u001b[0m         hidden_states \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mcheckpoint\u001b[38;5;241m.\u001b[39mcheckpoint(\n\u001b[1;32m    304\u001b[0m             block,\n\u001b[1;32m    305\u001b[0m             hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    312\u001b[0m             use_reentrant\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    313\u001b[0m         )\n\u001b[1;32m    314\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 315\u001b[0m         hidden_states \u001b[38;5;241m=\u001b[39m \u001b[43mblock\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    316\u001b[0m \u001b[43m            \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    317\u001b[0m \u001b[43m            \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    318\u001b[0m \u001b[43m            \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    319\u001b[0m \u001b[43m            \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    320\u001b[0m \u001b[43m            \u001b[49m\u001b[43mtimestep\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimestep\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    321\u001b[0m \u001b[43m            
\u001b[49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    322\u001b[0m \u001b[43m            \u001b[49m\u001b[43mclass_labels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_labels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    323\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    325\u001b[0m \u001b[38;5;66;03m# 3. Output\u001b[39;00m\n\u001b[1;32m    326\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_input_continuous:\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/diffusers/models/attention.py:197\u001b[0m, in \u001b[0;36mBasicTransformerBlock.forward\u001b[0;34m(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)\u001b[0m\n\u001b[1;32m    194\u001b[0m cross_attention_kwargs \u001b[38;5;241m=\u001b[39m cross_attention_kwargs\u001b[38;5;241m.\u001b[39mcopy() \u001b[38;5;28;01mif\u001b[39;00m cross_attention_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m {}\n\u001b[1;32m    195\u001b[0m gligen_kwargs \u001b[38;5;241m=\u001b[39m cross_attention_kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgligen\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 197\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattn1\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    198\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnorm_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    199\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43monly_cross_attention\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    200\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    201\u001b[0m \u001b[43m    
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    202\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_ada_layer_norm_zero:\n\u001b[1;32m    204\u001b[0m     attn_output \u001b[38;5;241m=\u001b[39m gate_msa\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m*\u001b[39m attn_output\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/diffusers/models/attention_processor.py:420\u001b[0m, in \u001b[0;36mAttention.forward\u001b[0;34m(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)\u001b[0m\n\u001b[1;32m    416\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_states, encoder_hidden_states\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, attention_mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcross_attention_kwargs):\n\u001b[1;32m    417\u001b[0m     \u001b[38;5;66;03m# The `Attention` class can call different attention processors / attention functions\u001b[39;00m\n\u001b[1;32m    418\u001b[0m     \u001b[38;5;66;03m# here we simply pass along all tensors to the selected processor class\u001b[39;00m\n\u001b[1;32m    419\u001b[0m     \u001b[38;5;66;03m# For standard processors that are defined here, `**cross_attention_kwargs` is empty\u001b[39;00m\n\u001b[0;32m--> 420\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocessor\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    421\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    422\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    423\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    424\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    425\u001b[0m \u001b[43m        
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcross_attention_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    426\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/editing/lib/python3.9/site-packages/diffusers/models/attention_processor.py:1039\u001b[0m, in \u001b[0;36mAttnProcessor2_0.__call__\u001b[0;34m(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)\u001b[0m\n\u001b[1;32m   1035\u001b[0m value \u001b[38;5;241m=\u001b[39m value\u001b[38;5;241m.\u001b[39mview(batch_size, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, attn\u001b[38;5;241m.\u001b[39mheads, head_dim)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m   1037\u001b[0m \u001b[38;5;66;03m# the output of sdp = (batch, num_heads, seq_len, head_dim)\u001b[39;00m\n\u001b[1;32m   1038\u001b[0m \u001b[38;5;66;03m# TODO: add support for attn.scale when we move to Torch 2.1\u001b[39;00m\n\u001b[0;32m-> 1039\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaled_dot_product_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1040\u001b[0m \u001b[43m    \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdropout_p\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\n\u001b[1;32m   1041\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1043\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m hidden_states\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mreshape(batch_size, 
\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, attn\u001b[38;5;241m.\u001b[39mheads \u001b[38;5;241m*\u001b[39m head_dim)\n\u001b[1;32m   1044\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m hidden_states\u001b[38;5;241m.\u001b[39mto(query\u001b[38;5;241m.\u001b[39mdtype)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Location of the pickled CIFAR-10 test batch. Set the CIFAR10_TEST_BATCH\n",
    "# environment variable to point at a local copy; when it is unset the original\n",
    "# machine-specific path is used, so the default behavior is unchanged.\n",
    "test_batch_file = os.environ.get(\n",
    "    \"CIFAR10_TEST_BATCH\",\n",
    "    \"/media/dongliang/10TB Disk/datasets/cifar-10-python/cifar-10-batches-py/test_batch\",\n",
    ")\n",
    "\n",
    "# Argument order follows evaluate_diffusion(test_batch_file, clean_diffusion,\n",
    "# poisoned_diffusion) as defined earlier in this notebook: sd_pipe is passed as\n",
    "# the clean model and diffusion_model as the poisoned one.\n",
    "# NOTE(review): a stored run of this cell ended in KeyboardInterrupt, so it is\n",
    "# presumably slow -- confirm the expected runtime before using Run All.\n",
    "evaluate_diffusion(test_batch_file, sd_pipe, diffusion_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d114d982-035c-43f9-97e8-08969fe4d713",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "editing",
   "language": "python",
   "name": "editing"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
