{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from rtpt import RTPT\n",
    "import torch\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.stable_diffusion import load_sd_components, load_text_components\n",
    "from tqdm import tqdm\n",
    "from hooks.block_activations import RescaleLinearActivations\n",
    "from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask\n",
    "import torch.nn.functional as F\n",
    "from utils.activation_detection import initial_neuron_selection, compute_noise_diff, multiscale_structural_similarity_index_measure, calculate_max_pairwise_ssim, neuron_refinement\n",
    "\n",
    "RTPT('XX', 'NeMo', 1).start()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Stable Diffusion v1-4 modules and freeze them: the VAE, the text encoder\n",
    "# and the UNet are moved to the GPU and have gradient tracking disabled.\n",
    "vae, unet, scheduler = load_sd_components('v1-4')\n",
    "tokenizer, text_encoder = load_text_components('v1-4')\n",
    "\n",
    "torch_device = \"cuda\"\n",
    "for frozen_module in (vae, text_encoder, unet):\n",
    "    frozen_module.to(torch_device)\n",
    "    frozen_module.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# General methods and algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.stable_diffusion import compute_text_embedding\n",
    "from utils.datasets import load_and_encode_image\n",
    "# methods to encode and decode prompts\n",
    "\n",
    "@torch.no_grad()\n",
    "def embed_prompt(prompt):\n",
    "    \"\"\"Return the embedding-layer output (first hidden state) for a text prompt.\n",
    "\n",
    "    The prompt is padded or truncated to ``tokenizer.model_max_length`` tokens and the\n",
    "    resulting token ids are moved to the text encoder's device before being embedded.\n",
    "    \"\"\"\n",
    "    tokenizer_options = dict(\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    token_ids = tokenizer(prompt, **tokenizer_options).input_ids\n",
    "    embedding_layer = text_encoder.text_model.embeddings\n",
    "    return embedding_layer(token_ids.to(text_encoder.device))\n",
    "\n",
    "@torch.no_grad()\n",
    "def unembed_prompt(hidden_states):\n",
    "    \"\"\"Map embedding-layer hidden states back to token ids.\n",
    "\n",
    "    The position embeddings are subtracted first; every remaining vector is then assigned\n",
    "    the id of the token embedding it is most similar to.\n",
    "    \"\"\"\n",
    "    embeddings = text_encoder.text_model.embeddings\n",
    "    positional_part = embeddings.position_embedding(embeddings.position_ids)\n",
    "    token_part = F.normalize(hidden_states - positional_part, p=2, dim=-1)\n",
    "    vocabulary = F.normalize(embeddings.token_embedding.weight, p=2, dim=-1)\n",
    "\n",
    "    # both operands are L2-normalized along the last axis, so the product holds cosine similarities\n",
    "    return torch.matmul(token_part, vocabulary.T).argmax(dim=-1)\n",
    "\n",
    "def encode_tokens(hidden_states):\n",
    "    \"\"\"Turn embedding-layer hidden states into the final text embedding.\n",
    "\n",
    "    The hidden states are passed through the text encoder's transformer stack with a\n",
    "    causal attention mask and then through its final layer norm. The result is the text\n",
    "    embedding that flows into the cross-attention layers.\n",
    "    \"\"\"\n",
    "    text_model = text_encoder.text_model\n",
    "    # the mask is built from the first two dimensions of the input\n",
    "    batch_and_sequence = hidden_states.shape[:2]\n",
    "    mask = _create_4d_causal_attention_mask(batch_and_sequence, hidden_states.dtype, device=hidden_states.device)\n",
    "    encoded = text_model.encoder(hidden_states, causal_attention_mask=mask, output_attentions=None, output_hidden_states=None)\n",
    "    return text_model.final_layer_norm(encoded[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image preprocessing applied before encoding into latent space: convert to a tensor,\n",
    "# resize to 512 (bilinear), center-crop to 512x512 and normalize with mean 0.5 / std 0.5.\n",
    "train_transforms = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.CenterCrop(size=(512, 512)),\n",
    "    transforms.Normalize(mean=[0.5], std=[0.5]),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NeMo Detection\n",
    "\n",
    "@torch.no_grad()\n",
    "def run_initial_selection(text_embedding):\n",
    "    \"\"\"Search for an initial set of neurons to block for ``text_embedding``.\n",
    "\n",
    "    Noise differences are first computed without blocking. Only samples whose maximum\n",
    "    pairwise SSIM exceeds the threshold are kept; if there are none, a message is printed\n",
    "    and ``None`` is returned. Otherwise ``initial_neuron_selection`` is called repeatedly,\n",
    "    lowering ``theta`` and raising ``k`` after every unsuccessful round, until the largest\n",
    "    multi-scale SSIM between the unblocked and the blocked noise differences no longer\n",
    "    exceeds the threshold or the search budget is exhausted.\n",
    "\n",
    "    Args:\n",
    "        text_embedding: text embedding forwarded to every helper call.\n",
    "\n",
    "    Returns:\n",
    "        The result of the last ``initial_neuron_selection`` call, or ``None`` if no\n",
    "        sample exceeded the SSIM threshold.\n",
    "    \"\"\"\n",
    "    theta = 5\n",
    "    min_theta = 1\n",
    "    theta_step = 0.25\n",
    "    k = 0\n",
    "    max_k = 1280\n",
    "    layer_depth = 7\n",
    "    ssim_threshold = 0.428\n",
    "    seed = 1\n",
    "    prompt = ''\n",
    "\n",
    "    # arguments shared by the unblocked and the blocked noise-difference computation\n",
    "    noise_diff_kwargs = dict(\n",
    "        seed=seed,\n",
    "        scaling_factor=0.0,\n",
    "        samples_per_prompt=10,\n",
    "        guidance_scale=0.0,\n",
    "        num_inference_steps=50,\n",
    "        text_embedding=text_embedding,\n",
    "    )\n",
    "\n",
    "    # reference run without any blocked neurons\n",
    "    noise_diff_unblocked = compute_noise_diff([prompt], tokenizer, text_encoder, unet, scheduler, blocked_indices=None, **noise_diff_kwargs)\n",
    "\n",
    "    max_ssims_per_noise_diff = calculate_max_pairwise_ssim(noise_diff_unblocked)\n",
    "    sample_indices_to_look_at = max_ssims_per_noise_diff > ssim_threshold\n",
    "\n",
    "    if sample_indices_to_look_at.sum() == 0:\n",
    "        print(f'No samples found with SSIM > {ssim_threshold}')\n",
    "        return None\n",
    "\n",
    "    noise_diff_unblocked = noise_diff_unblocked[sample_indices_to_look_at]\n",
    "\n",
    "    while True:\n",
    "        blocked_indices = initial_neuron_selection(prompt, tokenizer, text_encoder, unet, scheduler, layer_depth=layer_depth, theta=theta, k=k, seed=seed, text_embedding=text_embedding)\n",
    "        noise_diff_blocked = compute_noise_diff([prompt], tokenizer, text_encoder, unet, scheduler, blocked_indices=blocked_indices, seed_indices_to_return=sample_indices_to_look_at, **noise_diff_kwargs)\n",
    "        ssim = multiscale_structural_similarity_index_measure(noise_diff_unblocked, noise_diff_blocked, reduction='none', kernel_size=11, betas=(0.33, 0.33, 0.33)).max()\n",
    "\n",
    "        # success: blocking the current selection pushed the SSIM to or below the threshold\n",
    "        if ssim <= ssim_threshold:\n",
    "            break\n",
    "\n",
    "        # otherwise adjust the search parameters for the next round\n",
    "        if theta > min_theta:\n",
    "            theta -= theta_step\n",
    "        k += 1\n",
    "\n",
    "        # budget exhausted: give up and return the selection computed in this round\n",
    "        if theta <= min_theta or k >= max_k:\n",
    "            break\n",
    "    return blocked_indices\n",
    "\n",
    "@torch.no_grad()\n",
    "def refinement_step(text_embedding, blocked_indices):\n",
    "    \"\"\"Refine an initial neuron selection with ``neuron_refinement``.\n",
    "\n",
    "    The unblocked noise differences are recomputed to determine which samples exceed the\n",
    "    SSIM threshold; those samples are handed to ``neuron_refinement`` as ``seeds_to_look_at``.\n",
    "\n",
    "    Args:\n",
    "        text_embedding: text embedding forwarded to every helper call.\n",
    "        blocked_indices: initial neuron selection, passed on as ``input_indices``.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by ``neuron_refinement``.\n",
    "    \"\"\"\n",
    "    scaling_factor = 0\n",
    "    rel_threshold_refinement = None\n",
    "    ssim_threshold = 0.428\n",
    "    samples_per_prompt = 10\n",
    "    guidance_scale = 0\n",
    "    seed = 1\n",
    "    num_inference_steps = 50\n",
    "    prompt = ''\n",
    "\n",
    "    # reference run without blocking; only needed to pick the samples to look at\n",
    "    noise_diff_unblocked = compute_noise_diff([prompt], tokenizer, text_encoder, unet, scheduler, seed=seed, blocked_indices=None, scaling_factor=scaling_factor, samples_per_prompt=samples_per_prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, text_embedding=text_embedding)\n",
    "    sample_indices_to_look_at = calculate_max_pairwise_ssim(noise_diff_unblocked) > ssim_threshold\n",
    "\n",
    "    return neuron_refinement(prompt, tokenizer, text_encoder, unet, scheduler, input_indices=blocked_indices, scaling_factor=scaling_factor, threshold=ssim_threshold, rel_threshold=rel_threshold_refinement, samples_per_prompt=samples_per_prompt, guidance_scale=guidance_scale, seed=seed, seeds_to_look_at=sample_indices_to_look_at, text_embedding=text_embedding)\n",
    "\n",
    "\n",
    "# run detection\n",
    "@torch.no_grad()\n",
    "def run_nemo(text_embedding, initially_blocked_indices=None):\n",
    "    \"\"\"Run detection (initial selection + refinement) for ``text_embedding``.\n",
    "\n",
    "    If ``initially_blocked_indices`` is given, the listed neurons are zeroed during\n",
    "    detection through forward hooks on the ``attn2.to_v`` layers of the UNet. The hooks are\n",
    "    always removed again, even if detection raises, so the shared UNet is left unchanged.\n",
    "\n",
    "    Args:\n",
    "        text_embedding: text embedding to run detection for.\n",
    "        initially_blocked_indices: optional sequence of index collections. Entries 0-5\n",
    "            belong to the two attention layers of each of the first three down blocks\n",
    "            (entry ``down_block * 2 + attention``); the last entry belongs to the mid block.\n",
    "\n",
    "    Returns:\n",
    "        The refined selection returned by ``refinement_step``, or ``None`` if\n",
    "        ``run_initial_selection`` returned ``None``.\n",
    "    \"\"\"\n",
    "    block_handles = []\n",
    "    try:\n",
    "        # prune the already known neurons before running detection\n",
    "        if initially_blocked_indices:\n",
    "            for down_block in range(3):\n",
    "                for attention in range(2):\n",
    "                    indices = initially_blocked_indices[down_block * 2 + attention]\n",
    "                    block_hook = RescaleLinearActivations(indices=indices, factor=0)\n",
    "                    target_layer = unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v\n",
    "                    block_handles.append(target_layer.register_forward_hook(block_hook))\n",
    "            block_hook = RescaleLinearActivations(indices=initially_blocked_indices[-1], factor=0)\n",
    "            target_layer = unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v\n",
    "            block_handles.append(target_layer.register_forward_hook(block_hook))\n",
    "\n",
    "        blocked_indices = run_initial_selection(text_embedding)\n",
    "        if blocked_indices is None:\n",
    "            return None\n",
    "        return refinement_step(text_embedding, blocked_indices)\n",
    "    finally:\n",
    "        # also reached on exceptions and after a partial registration\n",
    "        for handle in block_handles:\n",
    "            handle.remove()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Output Embedding Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_output_embedding(prompt, img_path, steps, batch_size, blocked_indices=None):\n",
    "    \"\"\"Optimize the text embedding of ``prompt`` with the denoising loss on one image.\n",
    "\n",
    "    The image is encoded into latents and repeated ``batch_size`` times. In every step the\n",
    "    latents are noised at random timesteps and the MSE between the UNet prediction and the\n",
    "    added noise is minimized with Adam. Only the text embedding is handed to the optimizer.\n",
    "\n",
    "    Args:\n",
    "        prompt: text prompt whose embedding is the starting point.\n",
    "        img_path: path of the image passed to ``load_and_encode_image``.\n",
    "        steps: number of optimization steps.\n",
    "        batch_size: number of noisy copies of the image per step.\n",
    "        blocked_indices: optional sequence of index collections to zero during the\n",
    "            optimization. Entries 0-5 belong to the two attention layers of each of the\n",
    "            first three down blocks; the last entry belongs to the mid block.\n",
    "\n",
    "    Returns:\n",
    "        The optimized text embedding (a CUDA tensor with ``requires_grad`` enabled).\n",
    "    \"\"\"\n",
    "    # embed images and text\n",
    "    text_embedding_optimized = compute_text_embedding(prompt, tokenizer, text_encoder)\n",
    "    latents = load_and_encode_image(img_path, vae)\n",
    "    latents = torch.repeat_interleave(latents, dim=0, repeats=batch_size)\n",
    "\n",
    "    text_embedding_optimized = text_embedding_optimized.cuda()\n",
    "    text_embedding_optimized.requires_grad_(True)\n",
    "    latents = latents.cuda()\n",
    "\n",
    "    # set optimizer to update text embedding\n",
    "    optimizer = torch.optim.Adam(lr=1e-3, params=[text_embedding_optimized])\n",
    "\n",
    "    block_handles = []\n",
    "    try:\n",
    "        # add hooks to block memorization neurons\n",
    "        if blocked_indices:\n",
    "            for down_block in range(3):\n",
    "                for attention in range(2):\n",
    "                    indices = blocked_indices[down_block * 2 + attention]\n",
    "                    block_hook = RescaleLinearActivations(indices=indices, factor=0)\n",
    "                    target_layer = unet.down_blocks[down_block].attentions[attention].transformer_blocks[0].attn2.to_v\n",
    "                    block_handles.append(target_layer.register_forward_hook(block_hook))\n",
    "            block_hook = RescaleLinearActivations(indices=blocked_indices[-1], factor=0)\n",
    "            target_layer = unet.mid_block.attentions[0].transformer_blocks[0].attn2.to_v\n",
    "            block_handles.append(target_layer.register_forward_hook(block_hook))\n",
    "\n",
    "        # optimization loop: noise-prediction loss as in diffusion training, but only the\n",
    "        # text embedding is updated\n",
    "        for _ in tqdm(range(steps)):\n",
    "            noise = torch.randn_like(latents)\n",
    "            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=latents.device)\n",
    "            timesteps = timesteps.long()\n",
    "\n",
    "            noisy_latents = scheduler.add_noise(latents, noise, timesteps)\n",
    "            text_embedding_repeated = torch.repeat_interleave(text_embedding_optimized, dim=0, repeats=batch_size)\n",
    "            model_pred = unet(noisy_latents, timesteps, text_embedding_repeated, return_dict=False)[0]\n",
    "            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction=\"mean\")\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "    finally:\n",
    "        # remove hooks, also when the optimization is interrupted or fails\n",
    "        for handle in block_handles:\n",
    "            handle.remove()\n",
    "\n",
    "    return text_embedding_optimized"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "prompt = \"There's a <i>Mrs. Doubtfire</i> Sequel in the Works\"\n",
    "img_path = 'images/memorized_images/0018_571535772.png'\n",
    "steps = 100  # number of embedding-optimization steps per round\n",
    "batch_size = 4\n",
    "\n",
    "# starting text embedding for the prompt\n",
    "text_embedding = compute_text_embedding(prompt, tokenizer, text_encoder)\n",
    "\n",
    "# one (initially empty) index list per hooked layer: six down-block attention layers plus the mid block\n",
    "blocked_indices = [[] for _ in range(7)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.stable_diffusion import generate_images\n",
    "from utils.mitigations import Nemo\n",
    "\n",
    "# make sure the output folder exists before any image is written\n",
    "os.makedirs('generated_images', exist_ok=True)\n",
    "\n",
    "# keyword arguments shared by both generation calls\n",
    "generation_kwargs = dict(guidance_scale=7, seed=2, samples_per_prompt=4)\n",
    "\n",
    "for step in range(1, 6):\n",
    "    # detection step\n",
    "    detected_neurons = run_nemo(text_embedding, initially_blocked_indices=blocked_indices)\n",
    "    if detected_neurons is None:\n",
    "        break\n",
    "    # merge the newly detected neurons into the running selection, entry by entry\n",
    "    blocked_indices = [list(set(blocked_indices[i] + detected_neurons[i])) for i in range(len(blocked_indices))]\n",
    "\n",
    "    print('STEP', step, sum(len(neuron_list) for neuron_list in blocked_indices), blocked_indices)\n",
    "\n",
    "    nemo = Nemo(unet, blocked_indices, scaling_factor=0)\n",
    "\n",
    "    # generate images before the embedding is optimized\n",
    "    nemo.apply()\n",
    "    try:\n",
    "        images = generate_images(None, tokenizer, text_encoder, vae, unet, scheduler, text_embeddings=text_embedding, **generation_kwargs)\n",
    "    finally:\n",
    "        nemo.remove()\n",
    "    for i, image in enumerate(images):\n",
    "        image.save(f'generated_images/{step}_before_{i}.png')\n",
    "\n",
    "    # embedding optimization step; detach so later detection/generation does not track gradients\n",
    "    text_embedding = optimize_output_embedding(prompt, img_path, steps, batch_size, blocked_indices=blocked_indices).detach()\n",
    "\n",
    "    # generate images after the embedding is optimized, using the same call form as above.\n",
    "    # NOTE(review): the neurons are blocked through nemo.apply(); confirm that generate_images\n",
    "    # needs no additional blocked_indices argument.\n",
    "    nemo.apply()\n",
    "    try:\n",
    "        images = generate_images(None, tokenizer, text_encoder, vae, unet, scheduler, text_embeddings=text_embedding, **generation_kwargs)\n",
    "    finally:\n",
    "        nemo.remove()\n",
    "    for i, image in enumerate(images):\n",
    "        image.save(f'generated_images/{step}_after_{i}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
