{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU 7 to this process. This is set before torch is imported so the\n",
    "# restriction is in place before CUDA gets initialized.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\"\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from rtpt import RTPT\n",
    "import torch  # NOTE(review): duplicate of the import above; harmless\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.stable_diffusion import load_sd_components, load_text_components\n",
    "from tqdm import tqdm\n",
    "from hooks.block_activations import RescaleLinearActivations\n",
    "# NOTE(review): private transformers helper (leading underscore); its location may change between versions\n",
    "from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask\n",
    "import torch.nn.functional as F\n",
    "import copy\n",
    "from torchvision.datasets import CocoCaptions\n",
    "\n",
    "# Start RTPT tracking for this run (presumably renames the process for cluster monitoring;\n",
    "# the arguments look like owner initials, experiment name and number of iterations - TODO confirm)\n",
    "RTPT('XX', 'NeMo', 1).start()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load Stable Diffusion components and remove gradient computation\n",
    "def load_sd(version='v1-4', device='cuda'):\n",
    "    \"\"\"Load the Stable Diffusion VAE, UNet, scheduler, tokenizer and text encoder.\n",
    "\n",
    "    The VAE, text encoder and UNet are moved to ``device`` and frozen with\n",
    "    ``requires_grad_(False)``; a caller that wants to train one of them (or a copy)\n",
    "    must re-enable gradients explicitly. The scheduler and tokenizer are returned untouched.\n",
    "\n",
    "    Args:\n",
    "        version: model version forwarded to ``load_sd_components`` and ``load_text_components``.\n",
    "        device: torch device the three neural modules are moved to (default ``'cuda'``).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(vae, unet, scheduler, tokenizer, text_encoder)``.\n",
    "    \"\"\"\n",
    "    vae, unet, scheduler = load_sd_components(version)\n",
    "    tokenizer, text_encoder = load_text_components(version)\n",
    "\n",
    "    for module in (vae, text_encoder, unet):\n",
    "        module.to(device)\n",
    "        module.requires_grad_(False)\n",
    "    return vae, unet, scheduler, tokenizer, text_encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# methods to encode and decode prompts\n",
    "\n",
    "# encodes text prompt to output latent space, which directly flows into the cross-attention layers\n",
    "@torch.no_grad()\n",
    "def compute_text_embedding(prompts):\n",
    "    \"\"\"Encode one or more prompts into the text-encoder output used for cross-attention.\n",
    "\n",
    "    Args:\n",
    "        prompts: a single prompt string or a list/tuple of prompt strings.\n",
    "\n",
    "    Returns:\n",
    "        The first output of ``text_encoder`` for the token ids (padded/truncated to\n",
    "        ``tokenizer.model_max_length``), detached, with one batch entry per prompt.\n",
    "    \"\"\"\n",
    "    # wrap a single string so the tokenizer always receives a batch\n",
    "    if isinstance(prompts, str):\n",
    "        prompts = [prompts]\n",
    "    text_input = tokenizer(prompts,\n",
    "                           padding=\"max_length\",\n",
    "                           max_length=tokenizer.model_max_length,\n",
    "                           truncation=True,\n",
    "                           return_tensors=\"pt\")\n",
    "    text_embeddings = text_encoder(\n",
    "        text_input.input_ids.to(text_encoder.device))[0]\n",
    "\n",
    "    return text_embeddings.detach()\n",
    "\n",
    "# encodes text prompt to the first hidden state, i.e. the output of the embedding layer\n",
    "@torch.no_grad()\n",
    "def embed_prompt(prompt):\n",
    "    \"\"\"Return the output of the text model's embedding layer for ``prompt``.\n",
    "\n",
    "    The prompt is padded/truncated to ``tokenizer.model_max_length`` tokens and passed\n",
    "    only through ``text_encoder.text_model.embeddings`` (no transformer layers).\n",
    "    \"\"\"\n",
    "    token_ids = tokenizer(\n",
    "        prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\",\n",
    "    ).input_ids\n",
    "    embeddings_layer = text_encoder.text_model.embeddings\n",
    "    return embeddings_layer(token_ids.to(text_encoder.device))\n",
    "\n",
    "# maps hidden states back to tokens by computing the similarity between the hidden states and the token embeddings\n",
    "@torch.no_grad()\n",
    "def unembed_prompt(hidden_states):\n",
    "    \"\"\"Map embedding-layer hidden states back to the nearest token ids.\n",
    "\n",
    "    The position embeddings are subtracted first. Each remaining vector is then matched\n",
    "    to the token-embedding row with the highest cosine similarity.\n",
    "\n",
    "    Args:\n",
    "        hidden_states: tensor whose last dimension is the embedding size and whose\n",
    "            second-to-last dimension is the sequence position (e.g. the output of\n",
    "            ``embed_prompt``). Sequences shorter than the maximum length are supported.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with one token id per position.\n",
    "    \"\"\"\n",
    "    embeddings = text_encoder.text_model.embeddings\n",
    "    seq_len = hidden_states.shape[-2]\n",
    "    # only use as many positions as the input has, so shorter sequences broadcast correctly\n",
    "    position_ids = embeddings.position_ids[..., :seq_len]\n",
    "    position_embeddings = embeddings.position_embedding(position_ids)\n",
    "    hidden_states = hidden_states - position_embeddings\n",
    "    token_embedding_weights = embeddings.token_embedding.weight\n",
    "\n",
    "    # cosine similarity between every hidden state and every vocabulary embedding\n",
    "    similarity = torch.matmul(\n",
    "        F.normalize(hidden_states, p=2, dim=-1),\n",
    "        F.normalize(token_embedding_weights, p=2, dim=-1).T\n",
    "    )\n",
    "\n",
    "    return similarity.argmax(dim=-1)\n",
    "\n",
    "# maps the hidden states to the final text embedding, which directly flows into the cross-attention layers\n",
    "def encode_tokens(hidden_states):\n",
    "    \"\"\"Run embedding-layer hidden states through the text encoder's transformer and final layer norm.\n",
    "\n",
    "    No ``torch.no_grad`` is applied here, so the result can be back-propagated to\n",
    "    ``hidden_states``. Returns the layer-normalized last hidden state.\n",
    "    \"\"\"\n",
    "    text_model = text_encoder.text_model\n",
    "    batch_and_seq = hidden_states.shape[:2]\n",
    "    causal_mask = _create_4d_causal_attention_mask(\n",
    "        batch_and_seq, hidden_states.dtype, device=hidden_states.device\n",
    "    )\n",
    "    encoded = text_model.encoder(\n",
    "        hidden_states,\n",
    "        causal_attention_mask=causal_mask,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "    )[0]\n",
    "    return text_model.final_layer_norm(encoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transformations to apply to images before encoding into latent space\n",
    "train_transforms = transforms.Compose([\n",
    "    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]\n",
    "    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),  # shorter side -> 512\n",
    "    transforms.CenterCrop((512, 512)),  # square 512x512 crop\n",
    "    transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]\n",
    "])\n",
    "\n",
    "# encode image into latent space\n",
    "@torch.no_grad()\n",
    "def load_and_encode_image(img_path):\n",
    "    \"\"\"Load an image from disk and encode it into scaled VAE latents.\n",
    "\n",
    "    Args:\n",
    "        img_path: path to the image file.\n",
    "\n",
    "    Returns:\n",
    "        Latents with a leading batch dimension of 1, sampled from the VAE's latent\n",
    "        distribution and multiplied by ``vae.config.scaling_factor``.\n",
    "    \"\"\"\n",
    "    # Convert to RGB: PNGs may carry an alpha channel (or be grayscale), which would give a\n",
    "    # channel count the 3-channel SD VAE rejects. The context manager closes the file.\n",
    "    with Image.open(img_path) as pil_img:\n",
    "        pixels = train_transforms(pil_img.convert('RGB'))\n",
    "    pixels = pixels.unsqueeze(0).to(vae.device)\n",
    "    latents = vae.encode(pixels).latent_dist.sample()\n",
    "    return latents * vae.config.scaling_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.stable_diffusion import generate_images\n",
    "\n",
    "# generate images from a prompt or a precomputed text embedding with a given UNet and plot them\n",
    "def generate_and_visualize(prompt=None, text_embedding=None, unet=None, seed=2, guidance_scale=7, samples_per_prompt=4):\n",
    "    \"\"\"Generate images with ``unet`` and show them side by side.\n",
    "\n",
    "    Args:\n",
    "        prompt: text prompt; only used when ``text_embedding`` is not given.\n",
    "        text_embedding: precomputed text-encoder output (e.g. an optimized embedding).\n",
    "        unet: the UNet to sample with (required, e.g. ``unet_student`` or ``unet_teacher``).\n",
    "        seed: random seed forwarded to ``generate_images``.\n",
    "        guidance_scale: guidance scale forwarded to ``generate_images``.\n",
    "        samples_per_prompt: number of images to generate (default 4).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``unet`` is missing or neither ``prompt`` nor ``text_embedding`` is given.\n",
    "    \"\"\"\n",
    "    if unet is None:\n",
    "        raise ValueError('generate_and_visualize requires a unet (e.g. unet_student or unet_teacher)')\n",
    "    if text_embedding is None:\n",
    "        if prompt is None:\n",
    "            raise ValueError('either prompt or text_embedding must be given')\n",
    "        text_embedding = compute_text_embedding(prompt)\n",
    "    images = generate_images(None, tokenizer, text_encoder, vae, unet, scheduler, guidance_scale=guidance_scale, seed=seed, samples_per_prompt=samples_per_prompt, text_embeddings=text_embedding)\n",
    "\n",
    "    # squeeze=False keeps a 2-D axes array even when only one image is returned\n",
    "    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5), squeeze=False)\n",
    "    for ax, image in zip(axes[0], images):\n",
    "        ax.imshow(image)\n",
    "        ax.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters for optimization\n",
    "# NOTE(review): `prompt` and `steps` are not read by the cells below; the training cell\n",
    "# redefines the prompt as `mem_prompt` and sets `img_path` and `batch_size` again.\n",
    "prompt = \"The No Limits Business Woman Podcast\"\n",
    "img_path = 'images/memorized_images/0000_1030727993.png'\n",
    "steps, batch_size = 50, 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
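    "## Output Embedding Optimization\n",
    "\n",
    "Optimize the text embedding that is fed into the UNet's cross-attention layers so that the model reconstructs a memorized image. The unlearning loop below uses such embeddings as adversarial prompts while fine-tuning the student UNet."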
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_text_embeddings(img_path, unet, prompt=None, num_steps=15, batch_size=4, lr=1e-1):\n",
    "    \"\"\"Optimize a text embedding so that ``unet`` reconstructs the image at ``img_path``.\n",
    "\n",
    "    Follows the standard diffusion training objective (predict the added noise at a random\n",
    "    timestep) but updates the text embedding instead of the model weights.\n",
    "\n",
    "    Args:\n",
    "        img_path: path of the target image.\n",
    "        unet: UNet whose noise prediction is used. Its parameters are not modified and\n",
    "            receive no gradients: their ``requires_grad`` flags are disabled during the\n",
    "            search and restored afterwards. Without this, every backward pass would\n",
    "            accumulate into ``param.grad`` and leak into a later optimizer step on the\n",
    "            same model.\n",
    "        prompt: optional prompt used to initialize the embedding. If falsy, the embedding\n",
    "            is initialized from a standard normal distribution.\n",
    "        num_steps: number of Adam steps.\n",
    "        batch_size: number of noisy copies of the image per step.\n",
    "        lr: Adam learning rate (default ``1e-1``).\n",
    "\n",
    "    Returns:\n",
    "        The optimized embedding, detached.\n",
    "    \"\"\"\n",
    "    if prompt:\n",
    "        text_embedding = compute_text_embedding(prompt)\n",
    "    else:\n",
    "        # (1, 77, 768) matches the SD v1.x text-encoder output; TODO derive it for other versions\n",
    "        text_embedding = torch.randn(1, 77, 768).cuda()\n",
    "\n",
    "    latents = load_and_encode_image(img_path)\n",
    "    latents = torch.repeat_interleave(latents, dim=0, repeats=batch_size)\n",
    "\n",
    "    text_embedding = text_embedding.cuda()\n",
    "    text_embedding.requires_grad_(True)\n",
    "    latents = latents.cuda()\n",
    "\n",
    "    # the optimizer only updates the text embedding\n",
    "    optimizer = torch.optim.Adam(lr=lr, params=[text_embedding])\n",
    "\n",
    "    # Freeze the UNet for the duration of the search: optimizer.zero_grad() below only clears\n",
    "    # the embedding's gradient, so UNet gradients would otherwise pile up in its .grad buffers.\n",
    "    grad_flags = [param.requires_grad for param in unet.parameters()]\n",
    "    unet.requires_grad_(False)\n",
    "    try:\n",
    "        # standard diffusion training loop, with the text embedding as the optimized variable\n",
    "        for _ in range(num_steps):\n",
    "            noise = torch.randn_like(latents)\n",
    "            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=latents.device)\n",
    "            timesteps = timesteps.long()\n",
    "\n",
    "            noisy_latents = scheduler.add_noise(latents, noise, timesteps)\n",
    "            text_embedding_repeated = torch.repeat_interleave(text_embedding, dim=0, repeats=batch_size)\n",
    "            model_pred = unet(noisy_latents, timesteps, text_embedding_repeated, return_dict=False)[0]\n",
    "            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction=\"mean\")\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "    finally:\n",
    "        # restore the caller's requires_grad flags (e.g. a student model being fine-tuned)\n",
    "        for param, flag in zip(unet.parameters(), grad_flags):\n",
    "            param.requires_grad_(flag)\n",
    "\n",
    "    return text_embedding.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_noise_prediction(unet, text_embedding, blocked_indices, latents, timesteps, seed):\n",
    "    \"\"\"Predict noise with ``unet``, optionally blocking selected cross-attention value neurons.\n",
    "\n",
    "    Args:\n",
    "        unet: UNet used for the prediction.\n",
    "        text_embedding: text-encoder output; a batch of 1 is repeated to match ``latents``.\n",
    "        blocked_indices: ``None``/empty to disable blocking. Otherwise a sequence of 7 index\n",
    "            lists: entries 0-5 address ``down_blocks[0..2].attentions[0..1]`` (in that order)\n",
    "            and entry 6 the mid block. Each list is passed to a ``RescaleLinearActivations``\n",
    "            forward hook with ``factor=0`` on that block's ``attn2.to_v``.\n",
    "        latents: latents fed to the UNet.\n",
    "        timesteps: diffusion timesteps, one per latent.\n",
    "        seed: value passed to ``torch.manual_seed`` right before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        The UNet's noise prediction.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``blocked_indices`` is non-empty but does not have 7 entries.\n",
    "    \"\"\"\n",
    "    num_down_blocks, num_attentions = 3, 2\n",
    "    if blocked_indices:\n",
    "        expected = num_down_blocks * num_attentions + 1\n",
    "        if len(blocked_indices) != expected:\n",
    "            raise ValueError(f'blocked_indices must have {expected} entries (6 down-block attentions + mid block), got {len(blocked_indices)}')\n",
    "\n",
    "    block_handles = []\n",
    "    try:\n",
    "        # register the blocking hooks: 3 down blocks x 2 attentions, then the mid block\n",
    "        if blocked_indices:\n",
    "            for down_block in range(num_down_blocks):\n",
    "                for attention in range(num_attentions):\n",
    "                    indices = blocked_indices[down_block * num_attentions + attention]\n",
    "                    to_v = unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v\n",
    "                    block_handles.append(to_v.register_forward_hook(RescaleLinearActivations(indices=indices, factor=0)))\n",
    "            mid_to_v = unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v\n",
    "            block_handles.append(mid_to_v.register_forward_hook(RescaleLinearActivations(indices=blocked_indices[-1], factor=0)))\n",
    "\n",
    "        # make prediction with the (optionally) blocked network\n",
    "        if text_embedding.shape[0] == 1 and text_embedding.shape[0] != latents.shape[0]:\n",
    "            text_embedding = torch.repeat_interleave(text_embedding, latents.shape[0], dim=0)\n",
    "        torch.manual_seed(seed)\n",
    "        noise_pred = unet(latents, timesteps, text_embedding, return_dict=False)[0]\n",
    "    finally:\n",
    "        # Always remove the hooks, even if the forward pass fails. A leaked hook would keep\n",
    "        # blocking neurons in every later call on this UNet.\n",
    "        for handle in block_handles:\n",
    "            handle.remove()\n",
    "\n",
    "    return noise_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_reg_loss(unet_student, unet_teacher, norm='l2'):\n",
    "    \"\"\"Sum, over all parameter pairs, of the mean distance between student and teacher weights.\n",
    "\n",
    "    Args:\n",
    "        unet_student: model being fine-tuned.\n",
    "        unet_teacher: reference model; parameters are paired in ``parameters()`` order.\n",
    "        norm: ``'l2'`` (mean squared error per tensor) or ``'l1'`` (mean absolute error per tensor).\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor, or ``0`` if the models have no parameters.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown ``norm`` (instead of silently returning 0).\n",
    "    \"\"\"\n",
    "    distances = {'l2': torch.nn.functional.mse_loss, 'l1': torch.nn.functional.l1_loss}\n",
    "    if norm not in distances:\n",
    "        raise ValueError(f\"norm must be 'l2' or 'l1', got {norm!r}\")\n",
    "    distance = distances[norm]\n",
    "    reg_loss = 0\n",
    "    for student_param, teacher_param in zip(unet_student.parameters(), unet_teacher.parameters()):\n",
    "        reg_loss += distance(student_param, teacher_param)\n",
    "    return reg_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mem_prompt = \"The No Limits Business Woman Podcast\"\n",
    "img_path = 'images/memorized_images/0000_1030727993.png'\n",
    "# one index list per cross-attention block: 6 down-block attentions followed by the mid block\n",
    "blocked_indices = [[221], [], [], [], [], [], []]\n",
    "reg_loss_weight = 0\n",
    "unlearning_weight = 1 # maybe use a scheduler to decay over time\n",
    "num_unlearning_steps = 15  # outer unlearning iterations\n",
    "adv_search_steps = 15  # optimization steps per adversarial embedding search\n",
    "updates_per_step = 1  # student updates per adversarial embedding\n",
    "\n",
    "vae, unet_teacher, scheduler, tokenizer, text_encoder = load_sd()\n",
    "unet_student = copy.deepcopy(unet_teacher)\n",
    "unet_teacher.eval()\n",
    "unet_student.eval() # TODO: check if train or eval mode works better\n",
    "unet_student.requires_grad_(True)\n",
    "unet_teacher.requires_grad_(False)\n",
    "\n",
    "optimizer = torch.optim.Adam(lr=1e-4, params=unet_student.parameters())\n",
    "\n",
    "# number of noise samples per unlearning update and per embedding-search step\n",
    "batch_size = 4\n",
    "\n",
    "text_embedding_original = compute_text_embedding(mem_prompt)\n",
    "\n",
    "# load utility dataset\n",
    "def collate_first_caption(batch):\n",
    "    \"\"\"Stack the image tensors of a COCO batch and keep only the first caption of each image.\n",
    "\n",
    "    torch's default collate function raises when the samples of a batch carry different\n",
    "    numbers of captions (some COCO images have more than five), which made some batches\n",
    "    fail to load.\n",
    "    \"\"\"\n",
    "    images = torch.stack([image for image, _ in batch])\n",
    "    first_captions = [captions[0] for _, captions in batch]\n",
    "    return images, first_captions\n",
    "\n",
    "dataset = CocoCaptions(root='coco/images/val2017', annFile='coco/annotations/captions_val2017.json', transform=train_transforms)\n",
    "dataloader = torch.utils.data.DataLoader(dataset, batch_size=6, shuffle=True, collate_fn=collate_first_caption)\n",
    "dataloader_iter = iter(dataloader)\n",
    "\n",
    "for step in tqdm(range(num_unlearning_steps)):\n",
    "    # Find a text embedding that reconstructs the memorized image with the current student model.\n",
    "    # Alternates between starting from the memorized prompt (even steps) and random initialization (odd steps).\n",
    "    if step == 0:\n",
    "        text_embedding_mem = text_embedding_original\n",
    "    else:\n",
    "        init_prompt = mem_prompt if step % 2 == 0 else None\n",
    "        text_embedding_mem = find_text_embeddings(img_path, unet=unet_student, prompt=init_prompt, num_steps=adv_search_steps, batch_size=batch_size)\n",
    "\n",
    "    # statistics to keep track of the loss\n",
    "    total_loss_unlearning = 0.0\n",
    "    total_loss_utility = 0.0\n",
    "    total_loss = 0.0\n",
    "\n",
    "    # update the student model based on the memorized text embedding\n",
    "    for _ in range(updates_per_step):\n",
    "        # Start from clean gradients. Anything left in the student's .grad buffers (e.g. accumulated\n",
    "        # while searching for the embedding above) must not leak into this update.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        ### Unlearning loss: match the student's noise prediction without any pruning to the teacher's noise prediction with pruned memorization neurons\n",
    "        seed = torch.randint(0, 100000, (1,)).item()\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        noisy_latents = torch.randn(\n",
    "            (batch_size, unet_student.config.in_channels, 64, 64),\n",
    "        ).cuda()\n",
    "\n",
    "        # set timesteps to first timestep for SDv1.4 scheduler (for 50 inference steps)\n",
    "        timesteps = torch.full((noisy_latents.shape[0],), 981, dtype=torch.long, device=noisy_latents.device)\n",
    "\n",
    "        # the student gets the (adversarial) memorized embedding; the teacher gets the original prompt with blocked neurons\n",
    "        noise_pred_student = make_noise_prediction(unet_student, text_embedding_mem, blocked_indices=None, latents=noisy_latents, timesteps=timesteps, seed=seed)\n",
    "        with torch.no_grad():\n",
    "            noise_pred_teacher = make_noise_prediction(unet_teacher, text_embedding_original, blocked_indices=blocked_indices, latents=noisy_latents, timesteps=timesteps, seed=seed)\n",
    "\n",
    "        loss_unlearning = torch.nn.functional.mse_loss(noise_pred_teacher.float(), noise_pred_student.float(), reduction=\"mean\")\n",
    "\n",
    "        ### Utility loss: match the student's noise prediction to the teacher's noise prediction. No pruning is applied here.\n",
    "        try:\n",
    "            imgs, non_memorized_prompts = next(dataloader_iter)\n",
    "        except StopIteration:\n",
    "            # loader exhausted: start a new (reshuffled) pass over COCO\n",
    "            dataloader_iter = iter(dataloader)\n",
    "            imgs, non_memorized_prompts = next(dataloader_iter)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            latents_non_mem = vae.encode(imgs.cuda()).latent_dist.sample()\n",
    "            latents_non_mem = latents_non_mem * vae.config.scaling_factor\n",
    "\n",
    "        noise = torch.randn(\n",
    "            (latents_non_mem.shape[0], unet_student.config.in_channels, 64, 64),\n",
    "        ).cuda()\n",
    "\n",
    "        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents_non_mem.shape[0],), device=latents_non_mem.device)\n",
    "        timesteps = timesteps.long()\n",
    "        noisy_latents = scheduler.add_noise(latents_non_mem, noise, timesteps)\n",
    "        text_embedding_non_mem = compute_text_embedding(non_memorized_prompts).cuda()\n",
    "\n",
    "        # make prediction with updated network\n",
    "        seed += 1 # use different initial noises, maybe using the same seed as before is actually beneficial.\n",
    "\n",
    "        noise_pred_student = make_noise_prediction(unet_student, text_embedding_non_mem, blocked_indices=None, latents=noisy_latents, timesteps=timesteps, seed=seed)\n",
    "        with torch.no_grad():\n",
    "            noise_pred_teacher = make_noise_prediction(unet_teacher, text_embedding_non_mem, blocked_indices=None, latents=noisy_latents, timesteps=timesteps, seed=seed)\n",
    "\n",
    "        loss_utility = torch.nn.functional.mse_loss(noise_pred_student.float(), noise_pred_teacher.float(), reduction=\"mean\")\n",
    "\n",
    "        ### Regularization loss: match the student's weights to the teacher's weights.\n",
    "        # Only build the autograd graph for it when it contributes to the loss. With\n",
    "        # reg_loss_weight == 0 it is still computed (without gradients) for logging.\n",
    "        with torch.set_grad_enabled(reg_loss_weight != 0):\n",
    "            reg_loss = compute_reg_loss(unet_student, unet_teacher, norm='l1')\n",
    "\n",
    "        loss = unlearning_weight / (step + 1) * loss_unlearning + loss_utility + reg_loss_weight * reg_loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()  # clear gradients before the next embedding search\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        total_loss_unlearning += loss_unlearning.item()\n",
    "        total_loss_utility += loss_utility.item()\n",
    "\n",
    "    print(f'{step}: Total Loss: {total_loss:.4f}, Unlearning Loss: {total_loss_unlearning:.4f} (weight: {unlearning_weight}), Utility Loss: {total_loss_utility:.4f}, Regularization Loss: {reg_loss:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check whether the memorized prompt still reproduces the memorized image after unlearning\n",
    "generate_and_visualize(prompt=mem_prompt, unet=unet_student)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# search for an embedding that still reconstructs the memorized image with the student, then sample from it\n",
    "adv_embedding = find_text_embeddings(img_path, unet_student, prompt=mem_prompt, num_steps=15, batch_size=batch_size)\n",
    "generate_and_visualize(unet=unet_student, text_embedding=adv_embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# utility check: an unrelated prompt sampled with the unlearned student\n",
    "generate_and_visualize(unet=unet_student, prompt='A photo of a cute cat')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reference: the same prompt sampled with the original teacher\n",
    "generate_and_visualize(unet=unet_teacher, prompt='A photo of a cute cat')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
