{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"feature_fool_demo.ipynb","provenance":[],"collapsed_sections":["vtXRn09uK1Lw","HSRY6FjHK_IF","JA8IdphILGXO","EY-7tpZPLJv3","haBqnFUgLLG0"],"machine_shape":"hm","authorship_tag":"ABX9TyOWHWeTXsZN7bmjTJjFoMcH"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"MhNZNtYqLcmI"},"source":["# One Thing to Fool Them All: Generating Interpretable, Universal, and Physically-Realizable Adversarial Features\n","\n"]},{"cell_type":"markdown","metadata":{"id":"vtXRn09uK1Lw"},"source":["### Installing Packages and Downloading Data\n","This takes a couple of minutes to run the first time. It downloads images, labels, and model weights.\n"]},{"cell_type":"code","metadata":{"id":"L603Xqjl710A"},"source":["%%capture\n","%%bash \n","pip install -q pytorch-pretrained-biggan\n","pip install -q git+https://github.com/S-aiueo32/lpips-pytorch.git\n","pip install -q pytorch_pretrained_vit"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Gw6jTbv7Um76","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1634315069275,"user_tz":240,"elapsed":17,"user":{"displayName":"Stephen Casper","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgBbP5_BgrF8vkMwr_K3vWGmQA9Z0KahEkcaJnDSw=s64","userId":"09887276019171077702"}},"outputId":"598a6460-f4ce-4e8d-f8da-e4fe7013ecb1"},"source":["%%bash\n","# make a directory called data\n","if ! 
[ -d ./data/ ] ; then\n","    mkdir data/\n","    echo 'data dir successfully created :)'\n","fi"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["data dir successfully created :)\n"]}]},{"cell_type":"code","metadata":{"id":"t1hjaz0rk0Fp","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1634315120155,"user_tz":240,"elapsed":50895,"user":{"displayName":"Stephen Casper","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgBbP5_BgrF8vkMwr_K3vWGmQA9Z0KahEkcaJnDSw=s64","userId":"09887276019171077702"}},"outputId":"31a5c4f8-fc75-42b7-bc1e-02f9e25bd99a"},"source":["import sys\n","import os\n","import gdown\n","\n","# Download a set of 2k imagenet validation images\n","if not os.path.isfile('./data/imagenet2k.pkl'):\n","    gdown.download('https://drive.google.com/uc?id=1eksXWRHvv3qhCKOHQg90-F6tEifgZ67o', \n","                    './data/imagenet2k.pkl', quiet=True)\n","    \n","# Download labels for a set of 2k imagenet validation images\n","if not os.path.isfile('./data/imagenet2k_labels.pkl'):\n","    gdown.download('https://drive.google.com/uc?id=1loxsvOBkD9-C3u7j-mIaYuT6G86dzZj-', \n","                    './data/imagenet2k_labels.pkl', quiet=True)\n","    \n","# Download a dict of imagenet class labels\n","if not os.path.isfile('./data/imagenet_classes.pkl'):\n","    gdown.download('https://drive.google.com/uc?id=1AnniTzpmPHumxCDdfLTeCblvom3bWYt9', \n","                    './data/imagenet_classes.pkl', quiet=True)\n","\n","# Download a couple of images\n","if not os.path.isfile('./data/traffic_light.png'):\n","    gdown.download('https://drive.google.com/uc?id=1ycDA2zusMs_-upmN3T7xR5-M7nWGPL08', \n","                    './data/traffic_light.png', quiet=True)\n","if not os.path.isfile('data/bee.png'):\n","    gdown.download('https://drive.google.com/uc?id=14Y07EF0JmANV53Bgkh40acSHeEXkR6RB', \n","                    './data/bee.png', quiet=True)\n","    \n","# Download a zipped folder with 
various model weights\n","if not os.path.isfile('./ff_models.zip'):\n","    gdown.download('https://drive.google.com/uc?id=1M6AgTjzspmuTdEpo92ia6T6vFtlxBZdO', \n","                    './ff_models.zip', quiet=True)\n","    \n","print('Files successfully downloaded :)')\n"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Files successfully downloaded :)\n"]}]},{"cell_type":"code","metadata":{"id":"XzPqzLZak0td","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1634315136847,"user_tz":240,"elapsed":16706,"user":{"displayName":"Stephen Casper","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgBbP5_BgrF8vkMwr_K3vWGmQA9Z0KahEkcaJnDSw=s64","userId":"09887276019171077702"}},"outputId":"964023dc-99fd-4e71-98ec-cc3a1823d7b1"},"source":["%%bash\n","# unzipping\n","if ! [ -d ./ff_models/ ] ; then\n","    unzip -q ./ff_models.zip -d .\n","    echo 'ff_models successfully unzipped :)'\n","fi"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["ff_models successfully unzipped :)\n"]}]},{"cell_type":"markdown","metadata":{"id":"HSRY6FjHK_IF"},"source":["### Imports\n","\n"]},{"cell_type":"code","metadata":{"id":"gHp_zc4l791o"},"source":["import pickle\n","import copy\n","import random\n","from pathlib import Path\n","from time import time\n","from tqdm import tqdm\n","from collections import OrderedDict\n","from IPython.utils import io\n","import numpy as np\n","from scipy import ndimage\n","import cv2\n","import imageio\n","import matplotlib.pyplot as plt\n","from matplotlib import image\n","import torch\n","import torchvision.models as models\n","import torchvision.datasets as datasets\n","import torchvision.transforms as T\n","import torch.nn.functional as F\n","import torch.nn as nn\n","import torch.optim as optim\n","from lpips_pytorch import LPIPS\n","from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample,\n","                     
                  save_as_images, display_in_terminal)\n","from pytorch_pretrained_vit import ViT\n","from ff_models.biggan_disc import Discriminator\n","from ff_models.pytorch_pretrained_gans import make_gan\n","\n","assert torch.cuda.is_available(), 'In Colab, select [Runtime -> Change Runtime Type -> Hardware Accelerator -> GPU]'\n","device = 'cuda'"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JA8IdphILGXO"},"source":["### Constants, Transforms, and Data"]},{"cell_type":"code","metadata":{"id":"nNeluHB_9vIC"},"source":["# constants\n","N_CLASSES = 1000\n","PATCH_SIDE = 64\n","IMAGE_SIDE = 256\n","N_ROUND = 4\n","GAUSS_SIGMA = 0.1\n","MEAN = np.array([0.485, 0.456, 0.406])\n","STD = np.array([0.229, 0.224, 0.225])\n","\n","# transforms\n","resize64 = T.Resize((64, 64))\n","resize128 = T.Resize((128, 128))\n","resize256 = T.Resize((256, 256))\n","normalize = T.Normalize(mean=MEAN, std=STD)\n","unnormalize = T.Normalize(mean=-MEAN/STD, std=1/STD)\n","to_tensor = T.ToTensor()\n","def gaussian_noise(tens, sigma=GAUSS_SIGMA):\n","    noise = torch.randn_like(tens) * sigma\n","    return tens + noise.to(device)\n","cjitter = T.ColorJitter(0.25, 0.25, 0.25, 0.05)\n","def custom_colorjitter(tens):\n","    tens = unnormalize(tens)\n","    tens = cjitter(tens)\n","    tens = normalize(tens)\n","    return tens\n","\n","# for patch attacks\n","transforms_patch = T.Compose([custom_colorjitter, T.GaussianBlur(3, (.1, 1)), gaussian_noise,\n","                            T.RandomPerspective(distortion_scale=0.25, p=0.66), \n","                            T.RandomRotation(degrees=(-10, 10))]) \n","# for region and generalized patch attacks\n","transforms_im = T.Compose([T.GaussianBlur(3, (.1, .5)), T.RandomHorizontalFlip()]) \n","\n","# get data: imagenet classes, 2000 imagenet validation images, and labels\n","with open('data/imagenet_classes.pkl', 'rb') as f:\n","    class_dict = pickle.load(f)\n","with open('data/imagenet2k.pkl', 'rb') as 
f:\n","    imagenet2k = pickle.load(f)\n","with open('data/imagenet2k_labels.pkl', 'rb') as f:\n","    imagenet2k_labels = pickle.load(f)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EY-7tpZPLJv3"},"source":["### Load Models\n","This will take a minute to run the first time. "]},{"cell_type":"code","metadata":{"id":"cyg09spwAeS3"},"source":["%%capture\n","# load models including ensembles of classifiers, a BigGAN generator, and a BigGAN discriminator\n","\n","class Ensemble:\n","\n","    \"\"\"\n","    Ensembles together a set of classifiers, combining them by averaging their softmax outputs.\n","    \"\"\"\n","    \n","    def __init__(self, classifiers):\n","        self.cfs = [self.get_classifier(cf) for cf in classifiers]\n","        self.n_cfs = len(self.cfs)\n","\n","    def get_classifier(self, name):\n","        if name == 'vit':\n","            C = ViT('B_16_imagenet1k', pretrained=True, image_size=(256, 256)).to(device)\n","        elif 'robust' in name:\n","            C = models.resnet50(pretrained=False).eval().to(device)\n","            model_dict = C.state_dict()\n","            if name == 'resnet50_robust_l2':\n","                load_dict = torch.load('ff_models/imagenet_l2_3_0.pt')['model']\n","            elif name == 'resnet50_robust_linf':\n","                load_dict = torch.load('ff_models/imagenet_linf_4.pt')['model']\n","            else:\n","                raise ValueError('invalid robust model name')\n","            new_state_dict = OrderedDict()\n","            for mk in model_dict.keys():\n","                for lk in load_dict.keys():\n","                    if lk[13:] == mk:\n","                        new_state_dict[mk] = load_dict[lk]\n","            C.load_state_dict(new_state_dict)\n","            del model_dict\n","            del load_dict\n","        else:\n","            lcls = locals()\n","            exec(f'C = models.{name}(pretrained=True).eval().to(device)', globals(), lcls)\n","      
      C = lcls['C']\n","        return C\n","\n","    def __call__(self, inpt):\n","        outpts = [F.softmax(cf(inpt), 1) for cf in self.cfs]\n","        return sum(outpts) / self.n_cfs\n","\n","ALL_CLASSIFIERS = ['alexnet', 'resnet50', 'vgg19', 'inception_v3', 'densenet121', 'resnet50_robust_l2', 'resnet50_robust_linf', 'vit']\n","TRAIN_CLASSIFIERS = ['resnet50']\n","REG_CLASSIFIERS = ['resnet50_robust_l2', 'resnet50_robust_linf']\n","\n","E_attack = Ensemble(TRAIN_CLASSIFIERS)  # for attacking\n","E_reg = Ensemble(REG_CLASSIFIERS)  # for disguise and interpretability regularization \n","G = make_gan(gan_type='biggan', model_name='biggan-deep-256').to(device)\n","D = Discriminator()  # for interpretability regularization\n","D.load_state_dict(torch.load('ff_models/D.pth'))\n","D.to(device)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"haBqnFUgLLG0"},"source":["### Helper Functions\n","These functions are mostly related to calculating the loss and image processing. 
"]},{"cell_type":"code","metadata":{"id":"osPtV-MrBt0F"},"source":["%%capture\n","\n","def tensor_to_numpy_image(tensor, unnormalize_img=True):\n","    \"\"\"\n","    Takes a tensor and turns it into an imshowable np.ndarray\n","    \"\"\"\n","    image = tensor\n","    if unnormalize_img:\n","        image = unnormalize(image)\n","    image = image.detach().cpu().numpy()\n","    image = np.squeeze(image)\n","    image = np.transpose(image, axes=(1, 2, 0))\n","    image = np.clip(image, 0, 1)\n","    return image\n","\n","def numpy_image_to_tensor(array, normalize_img=True):\n","    \"\"\"\n","    Takes a 3-channel numpy image to a tensor that can be fed into networks'\n","    \"\"\"\n","    array = np.transpose(array, (2, 0, 1))\n","    array /= np.max(array)\n","    array = np.clip(array, 0, 1)\n","    tensor = torch.tensor(array, device=device, dtype=torch.float).unsqueeze(0)\n","    return normalize(tensor) if normalize_img else tensor\n","\n","def tensor_to_0_1(tensor):\n","    \"\"\"\n","    Shifts 0 to be at 0.5, then normalizes s.t. image falls on [0,1]\n","    \"\"\"\n","    return tensor / torch.max(tensor) / 2 + 0.5\n","\n","nll_loss = nn.NLLLoss()  # negative log likelihood\n","\n","class LPIPS_Device(LPIPS): \n","    \"\"\"\n","    Calculates perceptual distance between images. Used for regularization in region and generalized patch attacks. 
\n","    \"\"\"\n","    def __init__(self, net_type: str='alex', version: str='0.1'):\n","        super().__init__(net_type, version)\n","        # put the weights on device\n","        self.net.to(device)\n","        self.lin.to(device)\n","lpips_dist = LPIPS_Device(net_type='vgg', version='0.1')  # ['alex', 'squeeze', 'vgg']\n","\n","def total_variation(images):\n","    \"\"\"\n","    Calculates the summed L1 variation of images in tensor NCHW form\n","    \"\"\"\n","    if len(images.size()) == 4:\n","        h_var = torch.sum(torch.abs(images[:, :, :-1, :] - images[:, : ,1:, :]))\n","        w_var = torch.sum(torch.abs(images[:, :, :, :-1] - images[:, :, :, 1:]))\n","    else:  # if 3 (CHW)\n","        h_var = torch.sum(torch.abs(images[:, :-1, :] - images[: ,1:, :]))\n","        w_var = torch.sum(torch.abs(images[:, :, :-1] - images[:, :, 1:]))\n","    return h_var + w_var\n","\n","def entropy(sm_tensor, epsilon=1e-10):\n","    \"\"\"\n","    Returns a N length vector of entropies from an NxC tensor.\n","    \"\"\"\n","    log_sm_tensor = torch.log(sm_tensor+epsilon)\n","    h = -torch.sum(sm_tensor * log_sm_tensor, dim=1)  # formula for entropy\n","    return h\n","\n","def custom_loss_patch_adv(output, target, patch, lam_xent=3.0, lam_tvar=1e-3, \n","                          lam_disc=0.005, lam_patch_xent=0.2, lam_ent=0.2, quant=0.5, patch_bs=16):\n","    \"\"\"\n","    Calculates the targeted misclassification crossentropy loss with regularization based on \n","    total variation, discriminator realisticness confidence, classifier patch non-target confidence, \n","    and classifier patch entropy.\n","    \"\"\"\n","    avg_xent = nll_loss(torch.log(output), target)  # crossentropy (minimize)\n","    avg_tvar = total_variation(patch) / output.shape[0]  # avg total variation (minimize)\n","    loss = lam_xent*avg_xent + lam_tvar*avg_tvar\n","\n","    if lam_disc != 0:\n","        y = torch.tensor(list(range(N_CLASSES))).to(device)  # y for all classes\n"," 
       disc_out = D(patch, y)[:, 0]  # class conditioned output for all 1000 classes\n","        disc_q = torch.quantile(disc_out, quant)  # quantile marking the k highest\n","        disc = torch.mean(disc_out[disc_out > disc_q]) # discriminator conf, mean is over top k (maximize)\n","        loss -= lam_disc*disc\n","\n","    if lam_patch_xent != 0 or lam_ent != 0:\n","        patch256 = resize256(patch)\n","        classifiers_out = E_reg(torch.cat([transforms_patch(patch256) for i in range(patch_bs)], axis=0)) # what the classifiers think of the patch\n","        patch_xent = nll_loss(torch.log(classifiers_out), target[:patch_bs])  # cross entropy loss for target (maximize)\n","        ent = torch.mean(entropy(classifiers_out)) # entropy for softmax outputs (minimize)\n","        loss -= lam_patch_xent*patch_xent\n","        loss += lam_ent*ent \n","\n","    return loss\n","\n","def custom_loss_region_gen_patch_adv(output, target, perturbation, adv_img, orig_img, lam_tvar=1e-5, lam_lpips=4, \n","                        lam_disc=0.1, lam_wd = 0.0001, lam_patch_xent=0.1, lam_ent=0.2):\n","    \"\"\"\n","    For region and generalized patch adversaries. 
\n","    Calculates the targeted misclassification crossentropy loss with regularization based on \n","    total variation, LPIPS perceptual distance, discriminator realisticness confidence, \n","    perturbation norm, classifier patch non-target confidence, and classifier patch entropy.\n","    \"\"\"\n","    avg_x_ent = nll_loss(torch.log(output), target) # crossentropy (minimize)\n","    n_imgs = output.shape[0]\n","    avg_t_var = total_variation(orig_img-adv_img) / n_imgs  # avg total variation (minimize)\n","    wd = torch.mean(torch.linalg.norm(torch.flatten(perturbation, start_dim=1), dim=1))  # L2 perturbation norm (minimize)\n","    loss = avg_x_ent + lam_tvar*avg_t_var + lam_wd * wd\n","    \n","    if lam_lpips != 0:\n","        avg_lpips = lpips_dist(adv_img, orig_img) / n_imgs  # lpips perceptual distance (minimize)\n","        loss += lam_lpips * torch.squeeze(avg_lpips)\n","\n","    if lam_disc != 0:\n","        y = torch.tensor([list(range(N_CLASSES))]*adv_img.shape[0]).to(device) \n","        avg_disc = torch.mean(torch.topk(D(resize128(adv_img), y), 5, dim=0)[0])  # avg across top-5 disc values and across minibatch (maximize)\n","        loss -= lam_disc * avg_disc\n","\n","    if lam_patch_xent != 0 or lam_ent != 0:\n","        patches = [crop_to_square(get_gen_patch(*pair)) for pair in zip(orig_img, adv_img)]\n","        patches256 = torch.cat([resize256(patch) for patch in patches])  # get patches as full ims\n","        classifiers_out = E_reg(normalize(patches256))\n","        avg_patch_xent = nll_loss(torch.log(classifiers_out), target)  # classifier xent for *target* class (maximize)\n","        avg_ent = torch.mean(entropy(classifiers_out)).item()  # classifier softmax entropy (minimize)\n","        loss -= lam_patch_xent * avg_patch_xent \n","        loss += lam_ent * avg_ent\n","\n","    return loss\n","\n","def insert_patch(patch, batch_size, prop_lower=0.2, prop_upper=0.8, side_radius=10, transform=True, from_generator=False, 
y=None):\n","    \"\"\"\n","    For universal patch attacks, this randomly tiles images and inserts patches into them.\n","    \"\"\"\n","    if from_generator:  # if generating your own patches\n","        with torch.no_grad():\n","            ys = torch.cat([y]*batch_size, 0)\n","            rand_noises = G.sample_latent(batch_size=batch_size, device=device)\n","            images = normalize(G(rand_noises, ys))\n","            orig_images = copy.deepcopy(images).to(device)\n","    else:  # if using ImageNet validation set images\n","        rand_is = np.random.randint(0, imagenet2k_labels.shape[0], size=batch_size)\n","        images = normalize(torch.stack([to_tensor(imagenet2k[rand_i]) for rand_i in rand_is])).to(device)\n","        orig_images = copy.deepcopy(images).to(device)\n","    mid = (IMAGE_SIDE-PATCH_SIDE) // 2\n","    for i in range(batch_size): \n","        if transform:  # randomly transform and insert\n","            side = np.random.randint(PATCH_SIDE-side_radius, PATCH_SIDE+side_radius+1)\n","            rand_x = np.random.randint(int((IMAGE_SIDE-side)*prop_lower), \n","                                    int((IMAGE_SIDE-side)*prop_upper)+1)\n","            rand_y = np.random.randint(int((IMAGE_SIDE-side)*prop_lower), \n","                                    int((IMAGE_SIDE-side)*prop_upper)+1)\n","            to_insert = transforms_patch(T.functional.resize(patch, [side, side]))\n","            mask = to_insert != 0.0  # the mask makes any black parts of the patch not inserted\n","            images[i, :, rand_x: rand_x+side, rand_y: rand_y+side] *= torch.logical_not(mask)\n","            images[i, :, rand_x: rand_x+side, rand_y: rand_y+side] += mask * to_insert\n","        else:  # randomly insert\n","            rand_x = np.random.randint(int((IMAGE_SIDE-PATCH_SIDE)*prop_lower), \n","                                       int((IMAGE_SIDE-PATCH_SIDE)*prop_upper)+1)\n","            rand_y = 
np.random.randint(int((IMAGE_SIDE-PATCH_SIDE)*prop_lower), \n","                                       int((IMAGE_SIDE-PATCH_SIDE)*prop_upper)+1)\n","            images[i, :, rand_x: rand_x+PATCH_SIDE, rand_y: rand_y+PATCH_SIDE] = resize64(patch)\n","    return images, orig_images\n","\n","def get_mask(orig, adv, quant_threshold=0.9):\n","    \"\"\"\n","    For generalized patch attacks. Takes in two tensors, produces a bool mask tensor of their differences\n","    \"\"\"\n","    diff = tensor_to_numpy_image(tensor_to_0_1(adv-orig), False)\n","    smooth_absdiff = ndimage.gaussian_filter(np.abs(diff-0.5), 12)\n","    mask = smooth_absdiff > np.quantile(smooth_absdiff, quant_threshold)\n","    mask = np.any(mask, axis=-1) # differences on each color channel merged\n","    mask = ndimage.binary_opening(mask, iterations=4)\n","    mask = ndimage.binary_closing(mask, iterations=4, border_value=1)\n","    return torch.tensor(mask, device=device) \n","\n","def crop_to_square(patch):\n","    \"\"\"\n","    Takes a patch over a grey background and condenses it to a minimal bounding square. \n","    \"\"\"\n","    mask = patch[0,0] != 0.5  # just use the R channel as a heuristic\n","    adv_region = patch[np.ix_([True], [True, True, True], torch.any(mask, dim=1).cpu().numpy(), torch.any(mask, dim=0).cpu().numpy())]\n","    sh = adv_region.shape\n","    square = torch.ones((sh[0], sh[1], max([sh[2], sh[3]]), max([sh[2], sh[3]])), device=device) * 0.5\n","    square[:, :, 0:sh[2], 0:sh[3]] = adv_region\n","    return square\n","\n","def get_gen_patch(orig, adv):\n","    \"\"\"\n","    For generalized patch attacks. 
Returns a patch of the diff between the adv and orig imgs over a gray background\n","    \"\"\"\n","    mask = get_mask(orig, adv)\n","    patch = adv * mask[None, None, :, :]\n","    patch += 0.5 * torch.ones_like(adv) * torch.logical_not(mask[None, None, :, :])\n","    return patch   \n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dvXooUgkLNX5"},"source":["### Attack Training and Evaluation Functions\n","Where the magic happens. These functions perform patch, region, generalized-patch, and copy-paste attacks."]},{"cell_type":"code","metadata":{"id":"vIdDpnEOCimF"},"source":["def patch_adversary(n_batches=64, batch_size=32, lr=0.01, latent_i=8, \n","                    source_class=None, target_class=None, loss_hypers={}):\n","    \"\"\"\n","    This function trains an adversarial patch that is targeted, universal, interpretable, and \n","    physically-realizable. The success rate is variable for random choices of target classes, \n","    so try running it multiple times. 
\n","    \"\"\"\n","    # get target class\n","    if target_class is None:\n","        target_class = np.random.randint(N_CLASSES)\n","    target_tensor = torch.tensor([target_class]*batch_size, dtype=torch.long).to(device)\n","\n","    # if a class universal adversary\n","    if source_class is not None:\n","        source_tmp = G.sample_class(batch_size=1, device=device) * 0\n","        source_tmp[0][source_class] += 1\n","        source = source_tmp\n","\n","    # get latents from the patch generaor\n","    with torch.no_grad():\n","        y = G.sample_class(batch_size=1, device=device)\n","        orig_noise = G.sample_latent(batch_size=1, device=device)\n","        latents = G(orig_noise, y, return_latents=True)\n","        opt_latent = nn.Parameter(torch.clone(latents[latent_i]))\n","        optimizer = optim.Adam([opt_latent], lr=lr)\n","\n","    # generate patch, insert into images, and train\n","    for _ in tqdm(range(n_batches)):\n","        patch = normalize(G(opt_latent, y, original_noise=orig_noise, insertion_layer=latent_i))\n","        if source_class is None:  # if a universal attack\n","            patched_images, orig_images = insert_patch(patch[0], batch_size)\n","        else:  # if a class_universal attack\n","            patched_images, orig_images = insert_patch(patch[0], batch_size, from_generator=True, y=source)\n","        predictions = E_attack(patched_images)\n","        optimizer.zero_grad()\n","        loss = custom_loss_patch_adv(predictions, target_tensor, patch, **loss_hypers)\n","        loss.backward()\n","        optimizer.step()\n","\n","    # evaluate\n","    with torch.no_grad():\n","        patch = normalize(G(opt_latent, y, original_noise=orig_noise, insertion_layer=latent_i))\n","        if source_class is None:  # if a universal attack\n","            patched_images, _ = insert_patch(patch[0], batch_size) \n","        else:\n","            patched_images, _ = insert_patch(patch[0], batch_size, from_generator=True, 
y=source)\n","        adv_sm_out = E_attack(patched_images)\n","        mean_conf = round(np.mean(np.array([float(aso[target_class]) for aso in adv_sm_out])), N_ROUND)\n","        i_sm_out = E_reg(resize256(patch))\n","        i_class = int(torch.argmax(i_sm_out))\n","        i_conf = round(float(torch.max(i_sm_out)), N_ROUND)\n","        \n","    # show results\n","    plt.imshow(tensor_to_numpy_image(patch[0]))\n","    if source_class is None:\n","        plt.title(f'Universal Patch Adversary\\nlatent: {latent_i}\\ntarget: {class_dict[target_class]}, mean conf: {mean_conf}\\ndisguise: {class_dict[i_class]}, conf: {i_conf}'.title())\n","    else: \n","        plt.title(f'Class Universal Patch Adversary\\nlatent: {latent_i}\\nsource: {class_dict[source_class]}\\ntarget: {class_dict[target_class]}, mean conf: {mean_conf}\\ndisguise: {class_dict[i_class]}, conf: {i_conf}'.title())\n","    plt.xticks([])\n","    plt.yticks([])\n","    plt.show()    \n","\n","def assess_gp(patches, target_int, n_test=3, n_display=3, source_class=None):\n","    \"\"\"\n","    This function is called from inside assess_rgp and displays the generalized patches.\n","    \"\"\"\n","    with torch.no_grad():\n","        # Classes/noise for GAN (very finicky, change at your peril)\n","        y = G.sample_class(batch_size=n_test, device=device)\n","        if source_class is not None:\n","            y *= 0\n","            y[:,source_class] += 1\n","        y_int = torch.argmax(y, -1).detach().cpu().numpy()\n","        orig_noise = G.sample_latent(batch_size=n_test, device=device)\n","        orig_target =  torch.zeros(n_test)\n","\n","        # Set up fig and some stats tensors\n","        n_display = min([n_display, n_test])\n","        fig, axes = plt.subplots(1 + len(patches), 1 + n_display, figsize=(4*(1 + n_display), 5*(1 + len(patches))))\n","        gp_target = torch.zeros((len(patches), n_test))\n","        gp_mean_conf = torch.zeros(len(patches))\n","        gp_std_conf = 
torch.zeros(len(patches))\n","\n","        # Fill the first row with the original generated images\n","        orig_imgs = []\n","        for j in range(n_test):\n","            orig_img = G(orig_noise[[j]], y[[j]])\n","            orig_imgs.append(orig_img)\n","            if j < n_display:\n","                orig_sm_out = E_attack(normalize(orig_img))[0]\n","                orig_target[j] = round(float(orig_sm_out[target_int]), N_ROUND)\n","                axes[0, j+1].imshow(tensor_to_numpy_image(orig_img, False))\n","                axes[0, j+1].set_title(f'{class_dict[y_int[j]]}: {round(float(orig_sm_out[y_int[j]]), N_ROUND)}\\n {class_dict[target_int]}: {round(orig_target[j].item(), N_ROUND)}'.title(), fontweight=\"bold\")\n","\n","        # Fill out each successive row with each patch's results\n","        for i, patch in enumerate(patches):\n","            mask = patch != 0.5\n","            for j, orig_img in enumerate(orig_imgs):\n","                gp_img = orig_img * torch.logical_not(mask) + patch * mask\n","                gp_sm_out = E_attack(normalize(transforms_im(gp_img)))[0]\n","                gp_target[i, j] = round(float(gp_sm_out[target_int]), N_ROUND)\n","                if j < n_display:\n","                    axes[i+1, j+1].imshow(tensor_to_numpy_image(gp_img, False))\n","                    axes[i+1, j+1].set_title(f'{class_dict[y_int[j]]}: {round(float(gp_sm_out[y_int[j]]), N_ROUND)}\\n'.title() +\n","                                             f'{class_dict[target_int]}: {round(float(gp_sm_out[target_int]), N_ROUND)}'.title(), fontweight=\"bold\")\n","            \n","            gp_mean_conf[i] = torch.mean(gp_target[i,:])\n","            gp_std_conf[i] = torch.std(gp_target[i,:])\n","            square = crop_to_square(patch)\n","            \n","            # For fully grey composite patches (these occur when the generated patches do not overlap)\n","            if torch.numel(square) <= 1:\n","                square = 
patch\n","\n","            # Patches are evaluated on their own with the reg classifier\n","            reg_out = E_reg(normalize(resize256(square))).squeeze(0)\n","            axes[i+1, 0].imshow(tensor_to_numpy_image(resize256(square), False))\n","            axes[i+1, 0].set_title(f'Disguise conf ({class_dict[torch.argmax(reg_out).item()]}): {round(float(torch.max(reg_out)), N_ROUND)}\\n'.title() +\n","                                   f'Mean target conf: {round(float(gp_mean_conf[i]), N_ROUND)}\\n'.title() + \n","                                   f'Std target conf: {round(float(gp_std_conf[i]), N_ROUND)}'.title(), fontweight=\"bold\")\n","\n","        axes[0,0].axis('off')\n","        for ax in axes.flatten():\n","            ax.set_xticks([])\n","            ax.set_yticks([])\n","        plt.show()\n","\n","def assess_rgp(modification, layer, modify_fn, target_int, n_test=20, n_display=3, metadata=None, source_class=None):\n","    \"\"\"\n","    This function assesses and displays the results of a region attack and calls assess_gp to do the same for \n","    the corresponding generalized patch attack. 
\n","    \"\"\"\n","    with torch.no_grad():\n","        # Set up fig\n","        n_display = min([n_display, n_test])\n","        fig, axes = plt.subplots(n_display, 3, figsize=(15, 5*n_display))\n","\n","        # Classes/noise for GAN (very finicky, change at your peril)\n","        y = G.sample_class(batch_size=n_test, device=device)\n","        if source_class is not None:\n","            y *= 0\n","            y[:,source_class] += 1\n","        y_int = torch.argmax(y, -1).detach().cpu().numpy()\n","        orig_noise = G.sample_latent(batch_size=n_test, device=device)\n","        orig_target, adv_target, gps = [], [], []  # lists to save target confidences and generalized patches\n","\n","        # Generate and display some images; save ram by running the GAN with one image at a time\n","        for i in range(n_test): \n","            orig_latents = G(orig_noise[[i]], y[[i]], return_latents=True)\n","            orig_img = orig_latents[-1]\n","            adv_latent = modify_fn(torch.clone(orig_latents[layer]), modification)\n","            adv_img = G(adv_latent, y[[i]], original_noise=orig_noise[[i]], insertion_layer=layer)\n","            adv_sm_out = E_attack(normalize(adv_img))[0]\n","            adv_target.append(round(float(adv_sm_out[target_int]), N_ROUND))\n","            gps.append(get_gen_patch(orig_img, adv_img))   \n","\n","            # display the first n_display examples\n","            if i < n_display:\n","                orig_sm_out = E_attack(normalize(orig_img))[0]\n","                orig_target.append(round(float(orig_sm_out[target_int]), N_ROUND))\n","                axes[i, 0].imshow(tensor_to_numpy_image(orig_img, False))\n","                axes[i, 0].set_title(f'{class_dict[y_int[i]].title()}: {round(float(orig_sm_out[y_int[i]]), N_ROUND)}\\n {class_dict[target_int]}: {orig_target[i]}'.title(), fontweight = \"bold\")\n","                axes[i, 1].imshow(tensor_to_numpy_image(adv_img.squeeze(0), False))\n","                
axes[i, 1].set_title(f'{class_dict[y_int[i]].title()}: {round(float(adv_sm_out[y_int[i]]), N_ROUND)}\\n {class_dict[target_int]}: {adv_target[i]}'.title(), fontweight = \"bold\")\n","                axes[i, 2].imshow(tensor_to_numpy_image(tensor_to_0_1(adv_img.squeeze(0)-orig_img), False))\n","                axes[i, 2].set_title(f'Normalized pixel-level diff'.title(), fontweight = \"bold\") \n","\n","        fig.suptitle('Latent ' + str(layer) + \n","                     '\\nMean target confidence: ' + str(round(np.mean(adv_target), N_ROUND)) +\n","                     '\\nStd target confidence: ' + str(round(np.std(adv_target), N_ROUND)) + \n","                     (f'\\nLoss hyperparameters: {metadata[\"loss_hypers\"]}' if metadata is not None else \"\"), fontweight=\"bold\")\n","\n","        for ax in axes.flatten():\n","            ax.set_xticks([])\n","            ax.set_yticks([])\n","        plt.show()\n","\n","        # Make a composite patch by finding the regions that are perturbed in >80% of the patches\n","        mask_most = torch.sum(torch.stack([(gp[0,0] != 0.5) for gp in gps]), dim=0) > (0.8 * len(gps))\n","        patch_avg = torch.mean(torch.cat([gp for gp in gps]), dim=0) \n","        patch_comp = patch_avg * mask_most[None, None, :, :] + \\\n","                     0.5 * torch.ones_like(patch_avg) * torch.logical_not(mask_most[None, None, :, :])\n","        gps.append(patch_comp)  # add it to patch list\n","\n","        # Next, assess the generalized patches\n","        assess_gp(gps[(len(gps)-n_display):], target_int, source_class=source_class)\n","\n","def region_generalized_patch_adversary(prop_modified=1/8, latent_i=6, n_batches=128, batch_size=32, sub_batch_size=8, \n","                                       lr=0.05, source_class=None, target_class=None, loss_hypers={}):\n","    \"\"\"\n","    This function trains region and generalized patch attacks that are targeted, universal, \n","    and interpretable. 
The success rate is variable for random choices \n","    of target classes, so try running it multiple times. \n","    \"\"\"\n","\n","    # Fix batch size if needed\n","    batch_size -= batch_size % sub_batch_size\n","\n","    # Get target class for the attack\n","    if target_class is None:\n","        target_class = np.random.randint(N_CLASSES)\n","    target = torch.tensor(target_class, dtype=torch.long, device=device).unsqueeze(0)\n","    \n","    # If class-universal, set y (the source class) permanently\n","    if source_class is not None:\n","        y_int = torch.tensor(sub_batch_size * [source_class], device=device)\n","        y = torch.tensor(one_hot_from_int(sub_batch_size * [source_class], batch_size=sub_batch_size), device=device)\n","  \n","    # Get a sample pass through the generator and get params for the attack\n","    y_init = G.sample_class(batch_size=2, device=device)\n","    noise_init = G.sample_latent(batch_size=2, device=device)\n","    latent_init = G(noise_init, y_init, original_noise=noise_init, return_latents=True)[latent_i]\n","    region_side = int(np.sqrt(prop_modified) * latent_init.shape[-1])\n","    reg_x = np.random.randint(latent_init.shape[-1] - region_side + 1)\n","    reg_y = np.random.randint(latent_init.shape[-1] - region_side + 1)\n","\n","    # The modification parameterizes the perturbation\n","    modification = nn.Parameter(torch.zeros((latent_init.shape[1], region_side, region_side), device=device))\n","    optimizer = optim.Adam([modification], lr=lr)\n","\n","    # This function applies the perturbation \n","    def modify_fn(latent, perturbation):\n","        for i in range(latent.shape[0]):\n","            latent[i, :, reg_x:(reg_x+region_side), reg_y:(reg_y+region_side)] = perturbation\n","        return latent\n","    \n","    # Train the modification\n","    for step in tqdm(range(n_batches), position=0, leave=True):\n","        for batch_i in range(batch_size//sub_batch_size):  # Avoids overtaxing 
GPUs\n","            with torch.no_grad():\n","\n","                # Sample some source classes if it's a universal attack (and not a class universal one)\n","                if source_class is None:\n","                    y = G.sample_class(batch_size=sub_batch_size, device=device)\n","                    y_int = torch.argmax(y, -1)\n","\n","                # Generate a sub-batch of images and their latents\n","                orig_noise = G.sample_latent(batch_size=sub_batch_size, device=device)\n","                orig_latents = G(orig_noise, y, original_noise=orig_noise, return_latents=True)\n","                orig_latent = orig_latents[latent_i]\n","                orig_imgs = orig_latents[-1]\n","\n","            # calc loss and backward\n","            adv_latent = modify_fn(torch.clone(orig_latent), modification) \n","            adv_imgs = G(adv_latent, y, original_noise=orig_noise, insertion_layer=latent_i)\n","            adv_prediction = E_attack(normalize(transforms_im(adv_imgs)))\n","            loss = custom_loss_region_gen_patch_adv(adv_prediction, torch.tile(target, (sub_batch_size,)), \n","                                                    adv_latent-orig_latent, adv_imgs, orig_imgs, **loss_hypers)\n","            loss.backward()\n","        \n","        # optimize\n","        optimizer.step()\n","        optimizer.zero_grad() \n","\n","    assess_rgp(modification, latent_i, modify_fn, target.item(), source_class=source_class) \n","\n","def copy_paste_attack(source_file, patch_file, patch_side=85, prop_lower=0.2, prop_upper=0.8):\n","\n","    source_path = Path(f'data/{source_file}')\n","    source_im = resize256(normalize(to_tensor(image.imread(source_path)[:, :, :3]))).to(device)\n","    patch_im = copy.deepcopy(source_im)\n","\n","    patch_path = Path(f'data/{patch_file}')\n","    patch = normalize(to_tensor(image.imread(patch_path)[:, :, :3])).to(device)\n","    mid = (IMAGE_SIDE-patch_side) // 2\n","    diff = IMAGE_SIDE-patch_side\n","  
  rand_x = np.random.randint(int(diff*prop_lower), int(diff*prop_upper)+1)\n","    rand_y = np.random.randint(int(diff*prop_lower), int(diff*prop_upper)+1)\n","    patch_im[:, rand_x: rand_x+patch_side, rand_y: rand_y+patch_side] = T.functional.resize(patch, [patch_side, patch_side])\n","    \n","    orig_sm_out = E_attack(torch.unsqueeze(source_im, 0))[0]\n","    orig_label = torch.argmax(orig_sm_out)\n","    orig_conf = torch.max(orig_sm_out)\n","    patch_sm_out = E_attack(torch.unsqueeze(patch_im, 0))[0]\n","    patch_label = torch.argmax(patch_sm_out)\n","    patch_conf = torch.max(patch_sm_out)\n","\n","    fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n","    axes[0].imshow(tensor_to_numpy_image(source_im))\n","    axes[0].set_title(f'{class_dict[orig_label.item()]}: {round(orig_conf.item(), N_ROUND)}'.title(), fontweight='bold')\n","    axes[1].imshow(tensor_to_numpy_image(patch_im))\n","    axes[1].set_title(f'{class_dict[patch_label.item()]}: {round(patch_conf.item(), N_ROUND)}'.title(), fontweight='bold')\n","    for ax in axes:\n","        ax.set_xticks([])\n","        ax.set_yticks([])\n","    plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BQZJXNWqsZkk"},"source":["### Demo"]},{"cell_type":"code","metadata":{"id":"PKX6x4qRCuum"},"source":["# Generate universal patch attacks. \n","# They have variable success for random target classes, so run multiple times. \n","for _ in range(3):\n","    patch_adversary()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"h6w3pe9nFRvI"},"source":["# Generate class-universal patch attacks (using generated source images rather than real ones)\n","for _ in range(3):\n","    patch_adversary(source_class=309, target_class=308)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ExAuDMGoE181"},"source":["# Generate universal region and generalized patch attacks. 
\n","# They have variable success for random target classes, so run multiple times. \n","for _ in range(3):\n","    region_generalized_patch_adversary()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lCQV3hSAFMLg"},"source":["# Generate class-universal region and generalized patch attacks.\n","for _ in range(3):\n","    region_generalized_patch_adversary(source_class=397, target_class=396)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"0aBj2OzVCzLE"},"source":["# Simple function call to make a copy/paste attack. Upload your own images to create new ones. \n","for _ in range(3): \n","    copy_paste_attack('bee.png', 'traffic_light.png')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vwuDY53Lth6z"},"source":["### Tips\n","- If you hit CUDA out-of-memory errors, try reducing the ```batch_size``` or ```sub_batch_size``` arguments to the attack functions. \n","- If you want things to run more quickly, you can often get away with reducing ```n_batches```.\n","- Targeted, universal attacks tend to have variable success, especially when optimizing for a complex objective as in our case. So always run multiple trials. \n","- Attacks tend to be easiest to produce when using semantically related source/target class pairs such as bee and fly or pufferfish and lionfish. \n","- You can modify the ```latent_i``` param to change which block of the generator the perturbation is trained in. Using the very last one (```latent_i=13```) will result in a standard pixel-space attack. \n","- You can play around with the loss hyperparameters to get attacks optimized more or less for different parts of the objective. \n","- This code should be fairly easy to modify for your own experiments. The key functions to play with will be ```patch_adversary```, ```region_generalized_patch_adversary```, ```custom_loss_patch_adv```, and ```custom_loss_region_gen_patch_adv```. "]}]}