{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow\n",
    "import tensorflow.compat.v1 as tf\n",
    "import numpy as np\n",
    "from setup_cifar import CIFAR, CIFARModel\n",
    "from PIL import Image\n",
    "import Utils_CIFAR as util\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_batch_gradient(delta, imgs_idx, model, data, q, d, miu, s2):\n",
    "    \"\"\"Zeroth-order estimate of the batch-averaged loss gradient w.r.t. delta.\n",
    "\n",
    "    For each image x the estimate is the mean over q random directions u_i\n",
    "    (drawn with util.generate_u(s2, d)) of\n",
    "        (d / miu) * (f(x + delta + miu * u_i) - f(x + delta)) * u_i,\n",
    "    where f is util.f evaluated at the image's true label. The per-image\n",
    "    estimates are then averaged over imgs_idx.\n",
    "\n",
    "    Args:\n",
    "        delta: current perturbation; the result has the same shape.\n",
    "        imgs_idx: non-empty sequence of indices into data.test_data.\n",
    "        model: classifier queried through util.f / util.model_prediction.\n",
    "        data: dataset exposing test_data and test_labels (argmax over axis 1\n",
    "            gives the true class).\n",
    "        q: number of random directions per image.\n",
    "        d: size of the flattened perturbation (must equal delta.size).\n",
    "        miu: finite-difference step size.\n",
    "        s2: forwarded unchanged to util.generate_u.\n",
    "\n",
    "    Returns:\n",
    "        Array shaped like delta holding the averaged gradient estimate.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if imgs_idx is empty.\n",
    "    \"\"\"\n",
    "    imgs_idx = list(imgs_idx)\n",
    "    if not imgs_idx:\n",
    "        raise ValueError(\"imgs_idx must contain at least one image index\")\n",
    "    # The ground-truth labels do not depend on the image: decode them once.\n",
    "    true_label_list = np.argmax(data.test_labels, axis=1)\n",
    "    gradient_total = np.zeros_like(delta)\n",
    "    for image_id in imgs_idx:\n",
    "        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))\n",
    "        orig_img, _ = util.generate_data(data, image_id, orig_class)\n",
    "        true_label = true_label_list[image_id]\n",
    "\n",
    "        X_adv = orig_img + delta\n",
    "        # The unperturbed loss is the same for all q directions, so query it once\n",
    "        # (q + 1 model queries per image). Assumes util.f is a deterministic query.\n",
    "        predict_base = util.f(model, X_adv, true_label)\n",
    "        gradient_i = np.zeros((q, d))\n",
    "        for i in range(q):\n",
    "            u = util.generate_u(s2, d)\n",
    "            predict_pert = util.f(model, X_adv + miu * u, true_label)\n",
    "            gradient_i[i] = d / miu * (predict_pert - predict_base) * np.reshape(u, (1, d))\n",
    "        gradient = np.sum(gradient_i, axis=0) / q\n",
    "        gradient_total += np.reshape(gradient, np.shape(delta))\n",
    "    return gradient_total / len(imgs_idx)\n",
    "\n",
    "def compute_batch_loss(delta, imgs_idx, model, data):\n",
    "    \"\"\"Average util.f loss over imgs_idx, evaluated at clip(x + delta, -0.5, 0.5).\n",
    "\n",
    "    Args:\n",
    "        delta: perturbation added to every image in the batch.\n",
    "        imgs_idx: non-empty sequence of indices into data.test_data.\n",
    "        model: classifier queried through util.f / util.model_prediction.\n",
    "        data: dataset exposing test_data and test_labels.\n",
    "\n",
    "    Returns:\n",
    "        Mean of util.f over the batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if imgs_idx is empty.\n",
    "    \"\"\"\n",
    "    imgs_idx = list(imgs_idx)\n",
    "    if not imgs_idx:\n",
    "        raise ValueError(\"imgs_idx must contain at least one image index\")\n",
    "    # The ground-truth labels do not depend on the image: decode them once.\n",
    "    true_label_list = np.argmax(data.test_labels, axis=1)\n",
    "    total_loss = 0\n",
    "    for image_id in imgs_idx:\n",
    "        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))\n",
    "        orig_img, _ = util.generate_data(data, image_id, orig_class)\n",
    "        # Clip into [-0.5, 0.5], the pixel range this notebook works in.\n",
    "        adv_image = np.clip(orig_img + delta, -0.5, 0.5)\n",
    "        total_loss += util.f(model, adv_image, true_label_list[image_id])\n",
    "    return total_loss / len(imgs_idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from collections import namedtuple\n",
    "\n",
    "# Attack hyper-parameters:\n",
    "#   miu, q, s2 - estimator step size, number of random directions, util.generate_u argument\n",
    "#   d          - size of the flattened perturbation (32*32*3)\n",
    "#   k          - number of coordinates kept by hard thresholding\n",
    "#   eta        - step size; lname - log file; iter - not used in this cell\n",
    "Args = namedtuple('Args', ['iter', 'miu', 'd', 'k', 'eta', 'q', 's2', 'lname'])\n",
    "\n",
    "args = Args(iter=100, miu=0.001, q=10, s2=3072, eta=0.01, k=60, d=3072, lname='./CIFAR/result.txt')\n",
    "zomax = 600  # budget of zeroth-order queries (nizo)\n",
    "m = 10       # inner iterations per anchor point\n",
    "beta = 0.65  # weight of the variance-reduction correction term\n",
    "np.random.seed(42)\n",
    "random.seed(42)\n",
    "tf.set_random_seed(42)\n",
    "RS = np.random.RandomState(42)\n",
    "\n",
    "# The log is opened in append mode, which fails if its folder is missing.\n",
    "os.makedirs(os.path.dirname(args.lname) or '.', exist_ok=True)\n",
    "\n",
    "\n",
    "def log_line(file_text, console_text):\n",
    "    \"\"\"Append file_text to the attack log (args.lname) and print console_text.\"\"\"\n",
    "    with open(args.lname, 'a+') as f:\n",
    "        f.write(file_text)\n",
    "    print(console_text)\n",
    "\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    data, model = CIFAR(), CIFARModel('./models/cifar', sess, True)\n",
    "    true_label_list = np.argmax(data.test_labels, axis=1)\n",
    "\n",
    "    # images attacked jointly (all birds):\n",
    "    imgs_idx = [65, 67, 70, 75, 86, 113, 123, 129, 138, 149]\n",
    "\n",
    "    delta = np.zeros((1, 32, 32, 3))  # one perturbation shared by all images\n",
    "    hist_loss = []\n",
    "    hist_zo = []\n",
    "    hist_nht = []\n",
    "    nizo = 0  # zeroth-order queries spent so far\n",
    "    nht = 0   # hard-thresholding steps taken so far\n",
    "\n",
    "    total_loss = compute_batch_loss(delta, imgs_idx, model, data)\n",
    "    print(f\"total loss: {total_loss}\")\n",
    "    hist_loss.append(total_loss.item())\n",
    "    hist_zo.append(nizo)\n",
    "    hist_nht.append(nht)\n",
    "\n",
    "    while nizo < zomax:\n",
    "        inner_its = 0\n",
    "        anchor = delta.copy()\n",
    "\n",
    "        # Full gradient over the whole batch at the anchor point.\n",
    "        full_grad = compute_batch_gradient(anchor, imgs_idx, model, data, args.q, args.d, args.miu, args.s2)\n",
    "        nizo += len(imgs_idx) * (args.q + 1)\n",
    "\n",
    "        while (inner_its < m) and (nizo < zomax):\n",
    "            image_id = RS.choice(imgs_idx)\n",
    "            _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))\n",
    "            target_label = orig_class\n",
    "            orig_img, _ = util.generate_data(data, image_id, target_label)\n",
    "            true_label = true_label_list[image_id]\n",
    "            log_line(\"\\n Image ID:{}, infer label:{}, true label:{} \\n\".format(image_id, orig_class, true_label),\n",
    "                     \"Image ID:{}, infer label:{}, true label:{}\".format(image_id, orig_class, true_label))\n",
    "            if true_label != orig_class:\n",
    "                raise ValueError(\"True Label is different from the original prediction, pass!\")\n",
    "\n",
    "            # Variance-reduced step: gradient at delta, corrected by the same image's\n",
    "            # gradient at the anchor (weighted by beta), plus the full gradient.\n",
    "            # NOTE(review): the two util.compute_gradient calls presumably draw independent\n",
    "            # random directions; confirm whether shared directions were intended.\n",
    "            current_gradient = util.compute_gradient(model, orig_img + delta, true_label, args.s2, args.miu, args.q, args.d)\n",
    "            nizo += args.q + 1\n",
    "            anchor_gradient = util.compute_gradient(model, orig_img + anchor, true_label, args.s2, args.miu, args.q, args.d)\n",
    "            nizo += args.q + 1\n",
    "            delta_tmp = delta - args.eta * (beta * (current_gradient - anchor_gradient) + full_grad)\n",
    "\n",
    "            # Hard thresholding: keep only the k largest-magnitude coordinates.\n",
    "            delta_tmp = np.reshape(delta_tmp, (args.d))\n",
    "            top_k_idx = np.argsort(-np.abs(delta_tmp))[0:args.k]\n",
    "            delta = np.zeros_like(delta_tmp)\n",
    "            delta[top_k_idx] = delta_tmp[top_k_idx]\n",
    "            inner_its += 1\n",
    "            nht += 1\n",
    "\n",
    "            l2_dist = np.linalg.norm(delta, ord=2)\n",
    "            l0_dist = np.count_nonzero(delta) / args.d\n",
    "\n",
    "            delta = np.reshape(delta, (1, 32, 32, 3))\n",
    "            adv_image = np.clip(orig_img + delta, -0.5, 0.5)\n",
    "            _, attack_predict_class, _ = util.model_prediction(model, adv_image)\n",
    "            status = \"Succ\" if true_label != attack_predict_class else \"Fail\"\n",
    "            # nht was already incremented for this step, so it is the 1-based iteration number.\n",
    "            msg = \"Iter %d (%s): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d\" % (\n",
    "                nht, status, image_id, l0_dist, l2_dist, true_label, attack_predict_class)\n",
    "            log_line(msg + \" \\n\", msg)\n",
    "\n",
    "            total_loss = compute_batch_loss(delta, imgs_idx, model, data)\n",
    "            print(f\"total loss: {total_loss}\")\n",
    "            hist_loss.append(total_loss.item())\n",
    "            hist_zo.append(nizo)\n",
    "            hist_nht.append(nht)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# At the end of training, report predictions and save original/attacked images.\n",
    "import os\n",
    "\n",
    "label_list = {0: 'airplane',\n",
    "              1: 'automobile',\n",
    "              2: 'bird',\n",
    "              3: 'cat',\n",
    "              4: 'deer',\n",
    "              5: 'dog',\n",
    "              6: 'frog',\n",
    "              7: 'horse',\n",
    "              8: 'ship',\n",
    "              9: 'truck'}\n",
    "\n",
    "fig_path = './bvrzht/'\n",
    "os.makedirs(fig_path, exist_ok=True)\n",
    "\n",
    "\n",
    "def save_image(img, path):\n",
    "    \"\"\"Save an HxWx3 float image with values in [0, 1] to path as a borderless figure.\"\"\"\n",
    "    im = Image.fromarray((img * 255).astype(np.uint8), \"RGB\")\n",
    "    fig = plt.figure()\n",
    "    plt.imshow(im)\n",
    "    plt.axis('off')\n",
    "    plt.savefig(path, bbox_inches='tight')\n",
    "    plt.close(fig)  # avoid accumulating open figures inside the loop\n",
    "\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    data, model = CIFAR(), CIFARModel('./models/cifar', sess, True)\n",
    "\n",
    "    for i in imgs_idx:\n",
    "        # Use this image's own predicted class (as the training cell does) rather than\n",
    "        # the target_label left over from the last training iteration.\n",
    "        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[i], axis=0))\n",
    "        orig_img, _ = util.generate_data(data, i, orig_class)\n",
    "        adv_image = np.clip(orig_img + delta, -0.5, 0.5)\n",
    "        _, new_class, _ = util.model_prediction(model, np.expand_dims(adv_image[0], axis=0))\n",
    "\n",
    "        print(f'----- sample {i} --------')\n",
    "        print(f\"original class: {label_list[orig_class]}\")\n",
    "        print(f\"predicted class: {label_list[new_class]}\")\n",
    "        print(f'------------------------')\n",
    "\n",
    "        # Pixels live in [-0.5, 0.5]; shift by 0.5 into [0, 1] for saving.\n",
    "        save_image(np.reshape(adv_image, (32, 32, 3)) + 0.5, os.path.join(fig_path, f'attacked_{i}.jpg'))\n",
    "        save_image(np.reshape(np.clip(orig_img, -0.5, 0.5), (32, 32, 3)) + 0.5, os.path.join(fig_path, f'original_{i}.jpg'))\n",
    "\n",
    "print(f\"l2 distortion: {np.linalg.norm(delta)} in input space\")\n",
    "print(f\"non-zero entries in delta: {np.count_nonzero(delta)}\")\n",
    "\n",
    "# delta can contain negative entries, and casting negative floats to uint8 wraps\n",
    "# around; visualise the perturbation magnitude scaled to [0, 1] instead.\n",
    "perturb = np.abs(np.reshape(delta, (32, 32, 3)))\n",
    "if perturb.max() > 0:\n",
    "    perturb = perturb / perturb.max()\n",
    "save_image(perturb, os.path.join(fig_path, 'perturb.jpg'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss f(w) as a function of the number of zeroth-order queries (IZO).\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(hist_zo, hist_loss, linestyle='-', marker='^', markersize=5, label=f\"lr: {args.eta}\")\n",
    "ax.legend()\n",
    "ax.set_ylabel('f(w)')\n",
    "ax.set_xlabel('# IZO')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss f(w) as a function of the number of hard-thresholding steps (NHT).\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(hist_nht, hist_loss, linestyle='-', marker='^', markersize=5, label=f\"lr: {args.eta}\")\n",
    "ax.legend()\n",
    "ax.set_ylabel('f(w)')\n",
    "ax.set_xlabel('# NHT')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickle the loss curves so they can be reloaded and compared later.\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "results = {'nizo': hist_zo,\n",
    "           'nht': hist_nht,\n",
    "           'hist': hist_loss}\n",
    "save_dir = './save'\n",
    "os.makedirs(save_dir, exist_ok=True)  # open(..., 'wb') fails if the folder is missing\n",
    "with open(os.path.join(save_dir, 'bvr_curves.pickle'), 'wb') as file:\n",
    "    pickle.dump(results, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "20cfd2a0d09ae7bc83c97a517227e395fcdaac136dcb90e0b3ddfd33806f9ad5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
