{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running NatADiff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import numpy as np\n",
    "import torch\n",
    "import timm\n",
    "\n",
    "from transformers import AutoImageProcessor, ResNetForImageClassification\n",
    "from diffusers.schedulers import DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler\n",
    "\n",
    "from gcontrol.diffusion_pipelines import GCStableDiffusionPipeline\n",
    "from gcontrol.guidance_controllers.common import ClassifierFreeGuidance\n",
    "from gcontrol.guidance_controllers.stable_diffusion import ClassifierGuidance\n",
    "from gcontrol.utils import get_timm_config, array_to_PIL\n",
    "\n",
    "from guidance_controllers.adversarial_classifier import AdversarialClassifierGuidance\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "from torchvision.models import (\n",
    "    resnet50,\n",
    "    inception_v3,\n",
    "    vit_h_14,\n",
    "    ResNet50_Weights,\n",
    "    Inception_V3_Weights,\n",
    "    ViT_H_14_Weights,\n",
    ")\n",
    "from misc.classifier_pipeline import resnet50_config, inceptionv3_config, vith14_config, ClassifierPipeline\n",
    "from misc.path_configs import IMAGENET_CLASSES, CACHE_DIR\n",
    "from misc.experiment_helpers import get_embedding, get_distance_mat\n",
    "\n",
    "\n",
    "# Check that PyTorch can see a CUDA GPU. The rest of this notebook hard-codes\n",
    "# device=\"cuda\", so fail here with a clear message rather than deep inside a model load.\n",
    "if not torch.cuda.is_available():\n",
    "    raise RuntimeError(\"CUDA is not available; this notebook requires a CUDA-capable GPU.\")\n",
    "print(torch.cuda.get_device_name())\n",
    "print(torch.version.cuda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classifiers\n",
    "#\n",
    "# Every network is loaded with pretrained weights, switched to eval mode and moved to the\n",
    "# GPU in bfloat16, then wrapped in a ClassifierPipeline together with its config.\n",
    "\n",
    "# torchvision declares `weights` as a keyword argument; passing it positionally only works\n",
    "# through a deprecated compatibility path, so it is named explicitly here.\n",
    "resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).eval().to(dtype=torch.bfloat16, device=\"cuda\")\n",
    "inception = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1).eval().to(dtype=torch.bfloat16, device=\"cuda\")\n",
    "vit = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1).eval().to(dtype=torch.bfloat16, device=\"cuda\")\n",
    "\n",
    "# The two timm models are selected by checkpoint name.\n",
    "adv_resnet = timm.create_model(model_name=\"inception_resnet_v2.tf_ens_adv_in1k\", pretrained=True).eval().to(dtype=torch.bfloat16, device=\"cuda\")\n",
    "adv_inception = timm.create_model(model_name=\"adv_inception_v3.tf_adv_in1k\", pretrained=True).eval().to(dtype=torch.bfloat16, device=\"cuda\")\n",
    "\n",
    "# Rebind each name to its pipeline wrapper. For the timm models the config is derived from\n",
    "# the raw model, which is still what the name refers to while the right-hand side is evaluated.\n",
    "resnet = ClassifierPipeline(resnet, **resnet50_config)\n",
    "inception = ClassifierPipeline(inception, **inceptionv3_config)\n",
    "vit = ClassifierPipeline(vit, **vith14_config)\n",
    "adv_resnet = ClassifierPipeline(adv_resnet, **get_timm_config(adv_resnet))\n",
    "adv_inception = ClassifierPipeline(adv_inception, **get_timm_config(adv_inception))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick your surrogate model\n",
    "\n",
    "SURROGATE = \"resnet\"\n",
    "\n",
    "# Supported surrogate names mapped to their (classifier pipeline, config) pair.\n",
    "SURROGATE_CHOICES = {\n",
    "    \"resnet\": (resnet, resnet50_config),\n",
    "    \"inception\": (inception, inceptionv3_config),\n",
    "    \"vit\": (vit, vith14_config),\n",
    "}\n",
    "\n",
    "if SURROGATE not in SURROGATE_CHOICES:\n",
    "    # Without this check a mistyped name would leave surrogate_model / surrogate_config\n",
    "    # undefined, or silently keep the values from an earlier run of this cell.\n",
    "    raise ValueError(f\"Unknown SURROGATE {SURROGATE!r}; expected one of {sorted(SURROGATE_CHOICES)}\")\n",
    "\n",
    "surrogate_model, surrogate_config = SURROGATE_CHOICES[SURROGATE]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the diffusion model\n",
    "\n",
    "model_id = \"sd-legacy/stable-diffusion-v1-5\"\n",
    "\n",
    "# Guidance controller built around the surrogate classifier chosen above.\n",
    "guidance_controller = AdversarialClassifierGuidance(surrogate_model, **surrogate_config)\n",
    "\n",
    "pipe = GCStableDiffusionPipeline.from_pretrained(\n",
    "    model_id,\n",
    "    guidance_controller=guidance_controller,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    use_safetensors=True,\n",
    "    cache_dir=CACHE_DIR,\n",
    ")\n",
    "\n",
    "# Mark the VAE and UNet as excluded from CPU offload, swap in a DDIM scheduler that\n",
    "# reuses the loaded scheduler's config, then move the pipeline to the GPU.\n",
    "pipe._exclude_from_cpu_offload.extend([\"vae\", \"unet\"])\n",
    "pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\n",
    "pipe = pipe.to(\"cuda\")\n",
    "\n",
    "# Distance matrix between the embeddings of all ImageNet class labels\n",
    "class_names = list(IMAGENET_CLASSES[\"id2label\"].values())\n",
    "embeddings = get_embedding(class_names, pipe)\n",
    "dist_mat = get_distance_mat(embeddings, distance_type=\"cosine\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running NatADiff\n",
    "\n",
    "TRUE_CLASS = 292\n",
    "ADV_CLASS = 999\n",
    "SIMILARITY_TARGET = False\n",
    "\n",
    "mu = 0.2 # Blending parameter\n",
    "s = 200 # Adversarial classifier guidance strength\n",
    "\n",
    "############################################################################\n",
    "# Resolve the adversarial target BEFORE building any prompt, so that GPROMPT, the printed\n",
    "# summary and target_idx all refer to the same class.\n",
    "if SIMILARITY_TARGET:\n",
    "    # Take the two smallest distances in the TRUE_CLASS row of dist_mat. Entry 0 is\n",
    "    # presumably TRUE_CLASS itself (distance to itself), so entry 1 is used as the target.\n",
    "    sorted_idx = torch.topk(dist_mat[TRUE_CLASS], k=2, largest=False, sorted=True).indices.tolist()\n",
    "    ADV_CLASS = sorted_idx[1]\n",
    "\n",
    "PROMPT = IMAGENET_CLASSES[\"id2label\"][TRUE_CLASS]\n",
    "GPROMPT = IMAGENET_CLASSES[\"id2label\"][ADV_CLASS] + \" and \" + PROMPT\n",
    "\n",
    "print(\"TRUE CLASS        :\", PROMPT)\n",
    "print(\"ADVERSARIAL CLASS :\", IMAGENET_CLASSES[\"id2label\"][ADV_CLASS])\n",
    "print(\"BLENDING PARAMETER:\", mu)\n",
    "print(\"ADV GUIDANCE STRENGTH:\", s)\n",
    "print(\"PROMPT :\", PROMPT)\n",
    "print(\"GPROMPT:\", GPROMPT)\n",
    "print(\"=\" *80)\n",
    "\n",
    "# Fixed seed so that re-running the cell uses the same generator state.\n",
    "generator = torch.Generator(device=\"cuda\").manual_seed(1234)\n",
    "image = pipe(\n",
    "    prompt = PROMPT,\n",
    "    gprompt = GPROMPT,\n",
    "    generator = generator,\n",
    "    num_inference_steps = 200, \n",
    "    height = 512,\n",
    "    width = 512,\n",
    "    output_type = \"pil\",\n",
    "    target_idx = ADV_CLASS,\n",
    "    g_w = 7.5,\n",
    "    g_p = 7.5,\n",
    "    g_m = mu,\n",
    "    g_s = s,\n",
    "    classifier_guidance_bounds = (0, 700),\n",
    "    grad_norm = 2,\n",
    "    time_travel_sample = 5,\n",
    "    augmentations = \"recommended\",\n",
    "    time_travel_bounds=(500, 800),\n",
    ")\n",
    "\n",
    "image.images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get classification scores for generated image\n",
    "\n",
    "idx = 0\n",
    "\n",
    "# np.array(PIL image) is height x width x channel; permute to channel-first and add a\n",
    "# leading batch axis before moving the tensor to the GPU.\n",
    "pt_img = torch.tensor(np.array(image.images[idx]), dtype = torch.bfloat16).permute((2, 0 , 1))\n",
    "pt_img = pt_img.unsqueeze(0).cuda()\n",
    "\n",
    "# (printed label, classifier pipeline) pairs, in report order. The labels are padded to a\n",
    "# common width so the colons line up with the two header lines.\n",
    "classifiers = [\n",
    "    (\"resnet50 PREDICTION      \", resnet),\n",
    "    (\"inception PREDICTION     \", inception),\n",
    "    (\"vit PREDICTION           \", vit),\n",
    "    (\"adv resnet PREDICTION    \", adv_resnet),\n",
    "    (\"adv inception PREDICTION \", adv_inception),\n",
    "]\n",
    "\n",
    "print(f\"TRUE CLASS               : {PROMPT}\")\n",
    "print(f\"ADVERSARIAL CLASS        : {IMAGENET_CLASSES['id2label'][ADV_CLASS]}\\n\")\n",
    "\n",
    "# Only the forward pass is needed here, so skip building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    for label, classifier in classifiers:\n",
    "        probs = torch.softmax(classifier(pt_img), dim = 1)\n",
    "        pred = probs.argmax(dim = -1).item()\n",
    "        print(f\"{label}: {IMAGENET_CLASSES['id2label'][pred]} ({round(probs[0, pred].item()*100,2)}%)\")\n",
    "\n",
    "image.images[idx]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NatADiff",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
