{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial to use Stable Diffusion to generate images that can fool ResNet-50 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Setup Libraries for fooling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's install the required libraries!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install every third-party package imported below in one resolver pass.\n",
    "# torchvision, matplotlib and tqdm are imported by this notebook as well, so they are listed too.\n",
    "%pip install cma torch torchvision diffusers transformers accelerate matplotlib tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "# Expose only GPU 0 to this process. This runs before torch is imported in the next code cell,\n",
    "# presumably so that CUDA initialization only ever sees that device - TODO confirm.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's import all necessary libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "from tqdm.auto import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import cma\n",
    "import torch \n",
    "import torch.distributed as dist\n",
    "from torch.nn import functional as F\n",
    "\n",
    "from torchvision import models\n",
    "from diffusers import AutoPipelineForText2Image\n",
    "\n",
    "from IPython import display\n",
    "\n",
    "np.random.seed(1337)\n",
    "torch.manual_seed(1337)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up SD-Turbo for Image Generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set up SD-Turbo (the `stabilityai/sd-turbo` checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GenerativeModel:\n",
    "    \"\"\"Expose the SD-Turbo text-to-image pipeline as a function of flat latent vectors.\n",
    "\n",
    "    Attributes:\n",
    "        device: torch device string the pipeline is moved to.\n",
    "        pipe: the diffusers pipeline; its UNet is wrapped with torch.compile.\n",
    "        latent_channels: channel count of the UNet input latent (from the UNet config).\n",
    "        latent_size: spatial side length of the UNet input latent (from the UNet config).\n",
    "        num_latent_variables: length of one flattened latent vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, device=\"cuda:0\"):\n",
    "        \"\"\"Load the fp16 `stabilityai/sd-turbo` checkpoint onto `device` and compile its UNet.\"\"\"\n",
    "\n",
    "        self.device = device\n",
    "\n",
    "        # Initialize the SD-Turbo model\n",
    "        self.pipe = AutoPipelineForText2Image.from_pretrained(\n",
    "            \"stabilityai/sd-turbo\", torch_dtype=torch.float16, variant=\"fp16\"\n",
    "        ).to(self.device)\n",
    "\n",
    "        # Read the latent geometry from the UNet config (before the module is wrapped by\n",
    "        # torch.compile) instead of hard-coding the number of latent channels.\n",
    "        unet_config = self.pipe.unet.config\n",
    "        self.latent_channels      = unet_config.in_channels\n",
    "        self.latent_size          = unet_config.sample_size\n",
    "        self.num_latent_variables = self.latent_channels * self.latent_size * self.latent_size\n",
    "\n",
    "        self.pipe.unet = torch.compile(self.pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n",
    "\n",
    "        self.pipe.set_progress_bar_config(leave=False)\n",
    "\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def __call__(self, seeds, prompt):\n",
    "        \"\"\"Generate one image per latent vector.\n",
    "\n",
    "        Args:\n",
    "            seeds: array-like holding one flat latent of length `num_latent_variables`\n",
    "                per row (a list of 1-D arrays is accepted as well).\n",
    "            prompt: text prompt shared by every image of the batch.\n",
    "\n",
    "        Returns:\n",
    "            numpy array stacking the images returned by the pipeline, one per seed.\n",
    "        \"\"\"\n",
    "\n",
    "        # Stack into a single ndarray first: building a tensor straight from a list of\n",
    "        # ndarrays is extremely slow. dtype=float16 already yields a half tensor.\n",
    "        z = torch.tensor(np.asarray(seeds), dtype=torch.float16, device=self.device)\n",
    "        z = z.view(-1, self.latent_channels, self.latent_size, self.latent_size)\n",
    "\n",
    "        prompts = [prompt]*len(z)\n",
    "\n",
    "        # Single denoising step with guidance disabled, starting from the supplied latents\n",
    "        images = self.pipe(prompts, latents=z, num_inference_steps=1, guidance_scale=0.0)[\"images\"]\n",
    "        images = np.asarray([np.asarray(x) for x in images])\n",
    "\n",
    "        return images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a photo of a volcano"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generative_model = GenerativeModel()\n",
    "\n",
    "# Starting latent, text prompt, and the class index the classifier is expected to predict\n",
    "initial_seed = np.random.randn(generative_model.num_latent_variables).astype(np.half)\n",
    "prompt  = \"A ultra realsitic photo of a volcano\"\n",
    "label   = 980 # ImageNet1k index\n",
    "\n",
    "# Add a leading batch axis, generate, and keep the single resulting image\n",
    "image = generative_model(initial_seed[np.newaxis], prompt)[0]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(image)\n",
    "\n",
    "# Strip ticks and the frame so only the picture remains\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "for side in ('left', 'right', 'top', 'bottom'):\n",
    "    ax.spines[side].set_visible(False)\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save once without a title, then again with the prompt as the title\n",
    "fig.savefig(\"Image.png\", dpi=300)\n",
    "ax.set_title(prompt)\n",
    "fig.savefig(\"Generated Image.png\", dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Good Job Stable Diffusion! That definitely looks like a volcano! "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup ResNet-50 for Classification "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set up ResNet-50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierModel:\n",
    "    \"\"\"ResNet-50 (IMAGENET1K_V2 weights) classifier for batches of 0-255 channels-last images.\"\"\"\n",
    "\n",
    "    # Per-channel RGB statistics used by the torchvision ImageNet preprocessing\n",
    "    IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
    "    IMAGENET_STD  = (0.229, 0.224, 0.225)\n",
    "\n",
    "    def __init__(self, device=\"cuda:0\"):\n",
    "        \"\"\"Load the pretrained network onto `device`, switch to eval mode and trace it.\"\"\"\n",
    "\n",
    "        self.device      = device\n",
    "\n",
    "        self.model = models.resnet50(weights=\"IMAGENET1K_V2\").to(self.device)\n",
    "        self.model.eval()\n",
    "        self.model = torch.jit.trace(self.model, [torch.randn(32, 3, 224, 224, device=self.device)])\n",
    "\n",
    "        # Shaped (1, 3, 1, 1) so they broadcast over a (B, C, H, W) batch\n",
    "        self.mean = torch.tensor(self.IMAGENET_MEAN, device=self.device).view(1, 3, 1, 1)\n",
    "        self.std  = torch.tensor(self.IMAGENET_STD, device=self.device).view(1, 3, 1, 1)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def __call__(self, images):\n",
    "        \"\"\"Return per-class softmax probabilities for a batch of images.\n",
    "\n",
    "        Args:\n",
    "            images: array-like of shape (B, H, W, 3) with values in 0-255\n",
    "                (a list of (H, W, 3) arrays is accepted as well).\n",
    "\n",
    "        Returns:\n",
    "            Tensor on `self.device` with one row of class probabilities per image.\n",
    "        \"\"\"\n",
    "\n",
    "        # Stack into one ndarray first: building a tensor from a list of ndarrays is extremely slow\n",
    "        images = torch.tensor(np.asarray(images), device=self.device).float()\n",
    "        images = images / 255.0              # Scale to [0, 1]\n",
    "        images = images.permute(0, 3, 1, 2)  # Convert to (B, C, H, W) channels_first format\n",
    "\n",
    "        # The pretrained weights expect mean/std-normalized inputs; scaling to [0, 1] alone is not enough.\n",
    "        # NOTE(review): the reference preprocessing also resizes and center-crops to 224x224;\n",
    "        # images are still passed at their native resolution here - confirm whether that is intended.\n",
    "        images = (images - self.mean) / self.std\n",
    "\n",
    "        output = self.model(images)\n",
    "        probabilities = F.softmax(output, dim=1)\n",
    "\n",
    "        return probabilities"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's classify the generated image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_model = ClassifierModel()\n",
    "\n",
    "# Classify the generated image as a batch of one and keep its row of scores\n",
    "logits = classifier_model([image])[0]\n",
    "\n",
    "# Annotate the figure created in the generation cell with the outcome\n",
    "predicted_class  = logits.argmax().item()\n",
    "label_confidence = logits[label].item()\n",
    "ax.set_xlabel(f\"Prediction: {predicted_class}\\nConfidence on correct class: {label_confidence:.4f}\")\n",
    "\n",
    "fig.savefig(\"Classified Image.png\", dpi=300)\n",
    "\n",
    "fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Good Job ResNet! You are correct, it is a volcano!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup CMA-ES for Optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set up CMA-ES to optimize the initial seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps = 0.3\n",
    "num_iterations = 500\n",
    "\n",
    "# L∞ box of radius eps around the initial seed\n",
    "lower_bound = initial_seed - eps\n",
    "upper_bound = initial_seed + eps\n",
    "\n",
    "opts = cma.CMAOptions()\n",
    "opts.set(\"verbose\", -9)\n",
    "opts.set(\"bounds\", [lower_bound, upper_bound])\n",
    "\n",
    "# CMA-ES starts from the initial seed with an initial step size of 0.5\n",
    "optimizer = cma.CMAEvolutionStrategy(list(initial_seed), 0.5, opts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(seeds, prompt, label):\n",
    "    \"\"\"Generate an image for every seed and score it with the classifier.\n",
    "\n",
    "    Returns a dict echoing the inputs (`label`, `prompt`, `seeds`) together with the\n",
    "    generated `images`, the classifier output `logits`, its per-image arg-max `preds`,\n",
    "    and `fitnesses`: the classifier output for `label` as a numpy array.\n",
    "    \"\"\"\n",
    "    images = generative_model(seeds, prompt)\n",
    "    scores = classifier_model(images)\n",
    "\n",
    "    return {\n",
    "        'label': label,\n",
    "        'prompt': prompt,\n",
    "        'seeds': seeds,\n",
    "        'images': images,\n",
    "        'logits': scores,\n",
    "        'preds': scores.argmax(dim=1),\n",
    "        'fitnesses': scores[:, label].detach().cpu().numpy(),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: run the initial seed through the full generate -> classify path\n",
    "outputs = evaluate([initial_seed], prompt, label)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(outputs['images'][0])\n",
    "# .item() so the label shows the plain class index instead of a tensor repr\n",
    "ax.set(xlabel=f\"Prediction: {outputs['preds'][0].item()}\\nConfidence on correct class: {outputs['fitnesses'][0].item():.4f}\")\n",
    "ax.set(xticks=[], yticks=[])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lay the population out on a roughly square grid of subplots\n",
    "nc = int(np.sqrt(optimizer.popsize))\n",
    "nr = int(np.ceil(optimizer.popsize/float(nc)))\n",
    "\n",
    "fig = plt.figure(figsize=(nc*4, nr*4))\n",
    "axes = fig.subplots(nr, nc)\n",
    "for ax, _ in zip(axes.flat, range(optimizer.popsize)):\n",
    "    ax.set(title=prompt)\n",
    "    ax.set(xlabel=f\"Prediction: {None}\\nConfidence on correct class: {None}\")\n",
    "\n",
    "for ax in axes.flat:\n",
    "    ax.set(xticks=[], yticks=[])\n",
    "    ax.spines[['left', 'right', 'top', 'bottom']].set_visible(False)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "dh = display.display(fig, display_id=True)\n",
    "\n",
    "# One image artist per population member: created on the first generation and updated in\n",
    "# place afterwards, so the axes do not accumulate a new image on every generation.\n",
    "image_artists = [None] * optimizer.popsize\n",
    "\n",
    "# Optimizing Loop in the style of Ask -> Eval -> Tell\n",
    "all_outputs = []\n",
    "try:\n",
    "    with tqdm(total=num_iterations, desc='Generattions') as bar:\n",
    "        for generation in range(num_iterations):\n",
    "\n",
    "            seeds_population = np.asarray(optimizer.ask())\n",
    "\n",
    "            outputs = evaluate(seeds_population, prompt, label)\n",
    "\n",
    "            # The fitness handed to the optimizer is the classifier confidence on the correct class\n",
    "            optimizer.tell(seeds_population, outputs['fitnesses'])\n",
    "\n",
    "            all_outputs.append(outputs)\n",
    "\n",
    "            bar.set_postfix({\"Best Fitness:\": np.min(outputs['fitnesses'])})\n",
    "            bar.update(1)\n",
    "\n",
    "            members = zip(axes.flat, outputs['images'], outputs['preds'], outputs['fitnesses'])\n",
    "            for k, (ax, image, pred, fitness) in enumerate(members):\n",
    "                if image_artists[k] is None:\n",
    "                    image_artists[k] = ax.imshow(image)\n",
    "                else:\n",
    "                    image_artists[k].set_data(image)\n",
    "                ax.set(xlabel=f\"Prediction: {pred.item()}\\nConfidence on correct class: {fitness.item():.4f}\")\n",
    "\n",
    "            dh.update(fig)\n",
    "\n",
    "            # Stop as soon as any member of the population is misclassified\n",
    "            if bool((outputs[\"preds\"] != outputs[\"label\"]).any()):\n",
    "                break\n",
    "finally:\n",
    "    # Persist the history once, also when the loop is interrupted or raises, instead of\n",
    "    # re-serializing the whole growing list after every generation.\n",
    "    torch.save(all_outputs, \"outputs.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the collected per-generation outputs to disk so the next cell can reload them.\n",
    "torch.save(all_outputs, \"outputs.pt\")            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The saved history contains numpy arrays and plain dicts, which the weights-only unpickler\n",
    "# (the default on recent PyTorch releases) rejects. The file was written by this notebook,\n",
    "# so full unpickling is acceptable here - do not use this on files from untrusted sources.\n",
    "o = torch.load(\"outputs.pt\", map_location='cpu', weights_only=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_generation = o[-1]\n",
    "\n",
    "# Pick the adversarial member from the data rather than a hard-coded population index:\n",
    "# the first misclassified image, or the least confident one if none was misclassified.\n",
    "misclassified = torch.nonzero(final_generation['preds'] != final_generation['label']).flatten()\n",
    "if len(misclassified) > 0:\n",
    "    adversarial_idx = int(misclassified[0])\n",
    "else:\n",
    "    adversarial_idx = int(np.argmin(final_generation['fitnesses']))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(final_generation['images'][adversarial_idx])\n",
    "\n",
    "ax.set(xticks=[], yticks=[])\n",
    "ax.spines[['left', 'right', 'top', 'bottom']].set_visible(False)\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save once without annotations, then again with the prompt and the classifier verdict\n",
    "fig.savefig(\"Adversarial Image.png\", dpi=300)\n",
    "\n",
    "ax.set(title=prompt)\n",
    "ax.set(xlabel=f\"Prediction: {final_generation['preds'][adversarial_idx].item()}\\nConfidence on correct class: {final_generation['fitnesses'][adversarial_idx].item():.4f}\")\n",
    "\n",
    "fig.savefig(\"Classified Adversarial Image.png\", dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
