{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # restrict this notebook to GPU 0; set before torch is imported\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from PIL import Image\n",
    "from rtpt import RTPT\n",
    "from torchvision import transforms\n",
    "from tqdm import tqdm\n",
    "from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask\n",
    "\n",
    "from hooks.block_activations import RescaleLinearActivations\n",
    "from utils.mitigations import Nemo\n",
    "from utils.stable_diffusion import generate_images\n",
    "from utils.stable_diffusion import load_sd_components, load_text_components\n",
    "\n",
    "# start RTPT process tracking for this notebook run\n",
    "rtpt = RTPT('XX', 'NeMo', 1)\n",
    "rtpt.start()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the Stable Diffusion v1-4 components, move them to the GPU and freeze their weights\n",
    "vae, unet, scheduler = load_sd_components('v1-4')\n",
    "tokenizer, text_encoder = load_text_components('v1-4')\n",
    "\n",
    "torch_device = \"cuda\"\n",
    "# ending the cell with a for-statement keeps Jupyter from echoing a module repr\n",
    "for component in (vae, text_encoder, unet):\n",
    "    component.to(torch_device)\n",
    "    component.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# methods to encode and decode prompts\n",
    "\n",
    "@torch.no_grad()\n",
    "def compute_text_embedding(prompts):\n",
    "    \"\"\"Encode text prompt(s) into the text-encoder output that flows into the UNet cross-attention layers.\n",
    "\n",
    "    Prompts are padded/truncated to ``tokenizer.model_max_length`` tokens. Returns the first output of\n",
    "    ``text_encoder`` as a detached tensor (computed without gradient tracking).\n",
    "    \"\"\"\n",
    "    text_input = tokenizer(prompts,\n",
    "                           padding=\"max_length\",\n",
    "                           max_length=tokenizer.model_max_length,\n",
    "                           truncation=True,\n",
    "                           return_tensors=\"pt\")\n",
    "    input_ids = text_input.input_ids.to(text_encoder.device)\n",
    "    text_embeddings = text_encoder(input_ids)[0]\n",
    "    return text_embeddings.detach()\n",
    "\n",
    "@torch.no_grad()\n",
    "def embed_prompt(prompt):\n",
    "    \"\"\"Return the first hidden state of the text encoder for ``prompt``.\n",
    "\n",
    "    This is the output of ``text_encoder.text_model.embeddings`` for the prompt padded/truncated to\n",
    "    ``tokenizer.model_max_length`` tokens, computed without gradient tracking.\n",
    "    \"\"\"\n",
    "    token_ids = tokenizer(prompt,\n",
    "                          padding=\"max_length\",\n",
    "                          max_length=tokenizer.model_max_length,\n",
    "                          truncation=True,\n",
    "                          return_tensors=\"pt\").input_ids\n",
    "    embedding_layer = text_encoder.text_model.embeddings\n",
    "    return embedding_layer(token_ids.to(text_encoder.device))\n",
    "\n",
    "@torch.no_grad()\n",
    "def unembed_prompt(hidden_states):\n",
    "    \"\"\"Map embedding-layer hidden states back to token ids.\n",
    "\n",
    "    The position embeddings are subtracted first. Each position is then assigned the vocabulary token\n",
    "    whose embedding has the highest cosine similarity to what remains.\n",
    "    \"\"\"\n",
    "    embeddings = text_encoder.text_model.embeddings\n",
    "    token_part = hidden_states - embeddings.position_embedding(embeddings.position_ids)\n",
    "    # cosine similarity = dot product of L2-normalized vectors\n",
    "    unit_states = F.normalize(token_part, p=2, dim=-1)\n",
    "    unit_vocab = F.normalize(embeddings.token_embedding.weight, p=2, dim=-1)\n",
    "    similarity = unit_states @ unit_vocab.T\n",
    "    return similarity.argmax(dim=-1)\n",
    "\n",
    "def encode_tokens(hidden_states):\n",
    "    \"\"\"Run embedding-layer hidden states through the text encoder's transformer and final layer norm.\n",
    "\n",
    "    The result is the text embedding that flows into the cross-attention layers. Deliberately not\n",
    "    wrapped in ``torch.no_grad`` so gradients can flow back to ``hidden_states`` during optimization.\n",
    "    \"\"\"\n",
    "    text_model = text_encoder.text_model\n",
    "    # causal attention mask sized from hidden_states.shape[:2]\n",
    "    causal_mask = _create_4d_causal_attention_mask(\n",
    "        hidden_states.shape[:2], hidden_states.dtype, device=hidden_states.device\n",
    "    )\n",
    "    encoded = text_model.encoder(hidden_states,\n",
    "                                 causal_attention_mask=causal_mask,\n",
    "                                 output_attentions=None,\n",
    "                                 output_hidden_states=None)[0]\n",
    "    return text_model.final_layer_norm(encoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transformations to apply to images before encoding into latent space:\n",
    "# to tensor in [0, 1], resize the shorter side to 512, center-crop to 512x512, rescale to [-1, 1]\n",
    "train_transforms = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "        transforms.CenterCrop((512, 512)),\n",
    "        transforms.Normalize([0.5], [0.5]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "@torch.no_grad()\n",
    "def load_and_encode_image(img_path, vae):\n",
    "    \"\"\"Load an image file and encode it into scaled VAE latents with a leading batch dimension of 1.\n",
    "\n",
    "    The image is converted to RGB first, so grayscale, palette or RGBA files still yield the\n",
    "    3-channel input the Stable Diffusion VAE expects. The latent is sampled from the VAE's\n",
    "    ``latent_dist`` (stochastic) and multiplied by ``vae.config.scaling_factor``.\n",
    "    \"\"\"\n",
    "    with Image.open(img_path) as img:\n",
    "        pixels = train_transforms(img.convert(\"RGB\"))\n",
    "    # follow the VAE's device/dtype instead of hard-coding .cuda()\n",
    "    pixels = pixels.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)\n",
    "    latents = vae.encode(pixels).latent_dist.sample()\n",
    "    return latents * vae.config.scaling_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters for optimization\n",
    "optimization_seed = 0\n",
    "torch.manual_seed(optimization_seed)  # makes VAE sampling, noise and timestep draws reproducible on Restart & Run All\n",
    "prompt = \"The No Limits Business Woman Podcast\"\n",
    "img_path = 'prompts/memorized_images/0000_1030727993.png'\n",
    "steps = 50  # optimization steps\n",
    "batch_size = 4  # noisy copies of the image latent per step\n",
    "# per-layer indices of value neurons to block (hooks on attn2.to_v), in the order the hook loops below use them:\n",
    "# down_blocks[0..2] x attentions[0..1] (first 6 entries), then the mid block (last entry)\n",
    "blocked_indices = [[221], [], [], [], [], [], []]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Output Embedding Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed image and text\n",
    "text_embedding = compute_text_embedding(prompt)\n",
    "latents = load_and_encode_image(img_path, vae)\n",
    "latents = torch.repeat_interleave(latents, dim=0, repeats=batch_size)\n",
    "\n",
    "# the (detached) text embedding is the only tensor being optimized\n",
    "text_embedding = text_embedding.cuda()\n",
    "text_embedding.requires_grad_(True)\n",
    "latents = latents.cuda()\n",
    "\n",
    "optimizer = torch.optim.Adam(lr=1e-3, params=[text_embedding])\n",
    "\n",
    "# hooks that rescale the selected value neurons by factor 0; the lists exist before `try` so `finally` can always clean up\n",
    "block_handles = []\n",
    "block_hooks = []\n",
    "try:\n",
    "    if blocked_indices:\n",
    "        for down_block in range(3):\n",
    "            for attention in range(2):\n",
    "                indices = blocked_indices[down_block * 2 + attention]\n",
    "                block_hook = RescaleLinearActivations(indices=indices, factor=0)\n",
    "                block_handle = unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v.register_forward_hook(block_hook)\n",
    "                block_handles.append(block_handle)\n",
    "                block_hooks.append(block_hook)\n",
    "        block_hook = RescaleLinearActivations(indices=blocked_indices[-1], factor=0)\n",
    "        block_handle = unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v.register_forward_hook(block_hook)\n",
    "        block_handles.append(block_handle)\n",
    "        block_hooks.append(block_hook)\n",
    "        print(f'Number of blocked value neurons: {sum(len(hook.indices) for hook in block_hooks)}')\n",
    "\n",
    "    # standard diffusion training loop (predict the added noise), but gradients update the text embedding instead of the UNet\n",
    "    for step in tqdm(range(steps)):\n",
    "        noise = torch.randn_like(latents)\n",
    "        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=latents.device).long()\n",
    "\n",
    "        noisy_latents = scheduler.add_noise(latents, noise, timesteps)\n",
    "        text_embedding_repeated = torch.repeat_interleave(text_embedding, dim=0, repeats=batch_size)\n",
    "        model_pred = unet(noisy_latents, timesteps, text_embedding_repeated, return_dict=False)[0]\n",
    "        loss = F.mse_loss(model_pred.float(), noise.float(), reduction=\"mean\")\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        if step % 10 == 0:\n",
    "            print(f'{step}: {loss.item():.4f}')\n",
    "finally:\n",
    "    # remove the hooks even if the loop fails or is interrupted, so later cells run on an unmodified UNet\n",
    "    for handle in block_handles:\n",
    "        handle.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Images generated with the original prompt + pruned memorization neurons -> memorization is mitigated\n",
    "nemo = Nemo(unet, blocked_indices)\n",
    "nemo.apply()\n",
    "try:\n",
    "    images = generate_images([prompt], tokenizer, text_encoder, vae, unet, scheduler, guidance_scale=7, seed=2, samples_per_prompt=4)\n",
    "finally:\n",
    "    nemo.remove()  # always restore the UNet, even if generation fails\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D so this also works when a single image is returned\n",
    "fig, axes = plt.subplots(1, len(images), figsize=(20, 5), squeeze=False)\n",
    "for axis, image in zip(axes[0], images):\n",
    "    axis.imshow(image)\n",
    "    axis.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Images generated with the updated text embedding + pruned memorization neurons\n",
    "nemo = Nemo(unet, blocked_indices)\n",
    "nemo.apply()\n",
    "try:\n",
    "    # text_embedding requires grad after optimization; sampling needs no autograd graph through it\n",
    "    with torch.no_grad():\n",
    "        images = generate_images(None, tokenizer, text_encoder, vae, unet, scheduler, guidance_scale=7, seed=2, samples_per_prompt=4, text_embeddings=text_embedding.detach())\n",
    "finally:\n",
    "    nemo.remove()  # always restore the UNet, even if generation fails\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D so this also works when a single image is returned\n",
    "fig, axes = plt.subplots(1, len(images), figsize=(20, 5), squeeze=False)\n",
    "for axis, image in zip(axes[0], images):\n",
    "    axis.imshow(image)\n",
    "    axis.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input Embedding Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed image and text\n",
    "latents = load_and_encode_image(img_path, vae)\n",
    "latents = torch.repeat_interleave(latents, dim=0, repeats=batch_size)\n",
    "latents = latents.cuda()\n",
    "\n",
    "# the embedding-layer output (first hidden state) is the tensor being optimized\n",
    "hidden_states = embed_prompt(prompt)\n",
    "hidden_states = hidden_states.cuda()\n",
    "hidden_states.requires_grad_(True)\n",
    "\n",
    "optimizer = torch.optim.Adam(lr=1e-3, params=[hidden_states])\n",
    "\n",
    "# hooks that rescale the selected value neurons by factor 0; the lists exist before `try` so `finally` can always clean up\n",
    "block_handles = []\n",
    "block_hooks = []\n",
    "try:\n",
    "    if blocked_indices:\n",
    "        for down_block in range(3):\n",
    "            for attention in range(2):\n",
    "                indices = blocked_indices[down_block * 2 + attention]\n",
    "                block_hook = RescaleLinearActivations(indices=indices, factor=0)\n",
    "                block_handle = unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v.register_forward_hook(block_hook)\n",
    "                block_handles.append(block_handle)\n",
    "                block_hooks.append(block_hook)\n",
    "        block_hook = RescaleLinearActivations(indices=blocked_indices[-1], factor=0)\n",
    "        block_handle = unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v.register_forward_hook(block_hook)\n",
    "        block_handles.append(block_handle)\n",
    "        block_hooks.append(block_hook)\n",
    "        print(f'Number of blocked value neurons: {sum(len(hook.indices) for hook in block_hooks)}')\n",
    "\n",
    "    # same loop as for the output embedding, but the text embedding is re-encoded from the optimized hidden states each step\n",
    "    for step in tqdm(range(steps)):\n",
    "        noise = torch.randn_like(latents)\n",
    "        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=latents.device).long()\n",
    "\n",
    "        noisy_latents = scheduler.add_noise(latents, noise, timesteps)\n",
    "        # own name so the output-embedding result `text_embedding` is not overwritten by this section\n",
    "        encoded_embedding = encode_tokens(hidden_states)\n",
    "        encoded_embedding_repeated = torch.repeat_interleave(encoded_embedding, dim=0, repeats=batch_size)\n",
    "        model_pred = unet(noisy_latents, timesteps, encoded_embedding_repeated, return_dict=False)[0]\n",
    "        loss = F.mse_loss(model_pred.float(), noise.float(), reduction=\"mean\")\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        if step % 10 == 0:\n",
    "            print(f'{step}: {loss.item():.4f}')\n",
    "finally:\n",
    "    # remove the hooks even if the loop fails or is interrupted, so later cells run on an unmodified UNet\n",
    "    for handle in block_handles:\n",
    "        handle.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate images from the optimized hidden states (re-encoded into a text embedding) + pruned memorization neurons\n",
    "with torch.no_grad():\n",
    "    new_embedding = encode_tokens(hidden_states)\n",
    "    nemo = Nemo(unet, blocked_indices)\n",
    "    nemo.apply()\n",
    "    try:\n",
    "        # embedding-only call form: no prompt, all pipeline components, explicit text_embeddings\n",
    "        images = generate_images(None, tokenizer, text_encoder, vae, unet, scheduler, guidance_scale=7, seed=2, samples_per_prompt=4, text_embeddings=new_embedding)\n",
    "    finally:\n",
    "        nemo.remove()  # always restore the UNet, even if generation fails\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D so this also works when a single image is returned\n",
    "fig, axes = plt.subplots(1, len(images), figsize=(20, 5), squeeze=False)\n",
    "for axis, image in zip(axes[0], images):\n",
    "    axis.imshow(image)\n",
    "    axis.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map the optimized hidden states back to the nearest vocabulary tokens\n",
    "tokens = unembed_prompt(hidden_states)\n",
    "# drop special (start/end/padding) tokens so the text re-tokenizes as a normal prompt\n",
    "text = tokenizer.decode(tokens[0], skip_special_tokens=True)\n",
    "print(text)\n",
    "\n",
    "# generate images from the recovered prompt + pruned memorization neurons\n",
    "recovered_embedding = compute_text_embedding(text)\n",
    "nemo = Nemo(unet, blocked_indices)\n",
    "nemo.apply()\n",
    "try:\n",
    "    # embedding-only call form: no prompt, all pipeline components, explicit text_embeddings\n",
    "    images = generate_images(None, tokenizer, text_encoder, vae, unet, scheduler, guidance_scale=7, seed=2, samples_per_prompt=4, text_embeddings=recovered_embedding)\n",
    "finally:\n",
    "    nemo.remove()  # always restore the UNet, even if generation fails\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D so this also works when a single image is returned\n",
    "fig, axes = plt.subplots(1, len(images), figsize=(20, 5), squeeze=False)\n",
    "for axis, image in zip(axes[0], images):\n",
    "    axis.imshow(image)\n",
    "    axis.axis('off')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
