{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-09T08:33:51.652472Z",
     "iopub.status.busy": "2023-05-09T08:33:51.651580Z",
     "iopub.status.idle": "2023-05-09T08:33:54.771926Z",
     "shell.execute_reply": "2023-05-09T08:33:54.770194Z"
    }
   },
   "outputs": [],
   "source": [
    "import tensorflow\n",
    "import tensorflow.compat.v1 as tf\n",
    "import numpy as np\n",
    "from setup_cifar import CIFAR, CIFARModel\n",
    "from PIL import Image\n",
    "import Utils_CIFAR as util\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-09T08:33:54.777916Z",
     "iopub.status.busy": "2023-05-09T08:33:54.777531Z",
     "iopub.status.idle": "2023-05-09T08:33:54.786206Z",
     "shell.execute_reply": "2023-05-09T08:33:54.784945Z"
    }
   },
   "outputs": [],
   "source": [
    "def compute_batch_gradient(delta, imgs_idx, model, data, q, d, miu, s2):\n",
    "    gradient_total = np.zeros(d)\n",
    "    for image_id in imgs_idx:\n",
    "        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id],axis=0))  ## orig_class: predicted label;\n",
    "        target_label = orig_class\n",
    "        orig_img, y = util.generate_data(data,image_id,target_label)\n",
    "        true_label_list = np.argmax(data.test_labels, axis=1)\n",
    "        true_label = true_label_list[image_id]\n",
    "\n",
    "        gradient_i = np.zeros((q,d))\n",
    "        for i in range(q):\n",
    "            X_adv = orig_img + delta\n",
    "            u = util.generate_u(s2, d)\n",
    "            predict_1 = util.f(model, X_adv+miu*u,true_label)\n",
    "            predict_2 = util.f(model, X_adv,true_label)\n",
    "            u = np.reshape(u, (1, d))\n",
    "            gradient_i[i] = d/miu*(predict_1-predict_2)*u\n",
    "        gradient = np.sum(gradient_i,axis=0)/q\n",
    "        gradient = np.reshape(gradient,(1,32,32,3))\n",
    "        gradient_total += gradient\n",
    "    return gradient_total / len(imgs_idx)\n",
    "    \n",
    "def compute_batch_loss(delta, imgs_idx, model, data):\n",
    "    total_loss = 0\n",
    "    for image_id in imgs_idx:\n",
    "        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id],axis=0))  ## orig_class: predicted label;\n",
    "        target_label = orig_class\n",
    "        orig_img, y = util.generate_data(data,image_id,target_label)\n",
    "        true_label_list = np.argmax(data.test_labels, axis=1)\n",
    "        true_label = true_label_list[image_id]\n",
    "        X_adv = orig_img + delta\n",
    "        adv_image = np.clip(X_adv,-0.5,0.5)\n",
    "        total_loss += util.f(model, adv_image,true_label)\n",
    "    return total_loss / len(imgs_idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-09T08:33:54.792542Z",
     "iopub.status.busy": "2023-05-09T08:33:54.791974Z",
     "iopub.status.idle": "2023-05-09T08:34:10.336680Z",
     "shell.execute_reply": "2023-05-09T08:34:10.335505Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from collections import namedtuple\n",
    "Args = namedtuple('Args', ['iter', 'miu', 'd', 'k', 'eta', 'q', 's2', 'lname'])\n",
    "args = Args(iter=100, miu=0.001, q=10, s2=3072, eta=0.005, k=60, d=3072, lname='./CIFAR/result.txt')\n",
    "zomax = 600\n",
    "np.random.seed(42)\n",
    "random.seed(42)\n",
    "tf.set_random_seed(42)\n",
    "RS = np.random.RandomState(42)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    data, model = CIFAR(), CIFARModel('./models/cifar',sess,True)\n",
    "    imgs_idx = [65, 67, 70, 75, 86, 113, 123, 129, 138, 149]\n",
    "    use_log = True\n",
    "\n",
    "    succ_count, ii, iii = 0, 0, 0\n",
    "\n",
    "    image_number = len(imgs_idx)\n",
    "    l2_distortion_collect = np.zeros(image_number)\n",
    "    attack_succ_count = np.zeros(image_number)\n",
    "    cc = 0\n",
    "    cc2 = 0\n",
    "\n",
    "    delta = np.zeros((1,32,32,3))\n",
    "    hist_loss = []\n",
    "    hist_zo = []\n",
    "    hist_nht = []\n",
    "    nizo = 0\n",
    "    nht = 0\n",
    "\n",
    "    total_loss = compute_batch_loss(delta, imgs_idx, model, data)\n",
    "    print(f\"total loss: {total_loss}\")\n",
    "    hist_loss.append(total_loss.item())\n",
    "    hist_zo.append(nizo)\n",
    "    hist_nht.append(nht)\n",
    "\n",
    "    while nizo < zomax:\n",
    "        attack_flag = False\n",
    "        image_id = RS.choice(imgs_idx)\n",
    "        orig_prob, orig_class, orig_prob_str = util.model_prediction(model,\n",
    "                                                                        np.expand_dims(data.test_data[image_id],axis=0))\n",
    "        target_label = orig_class\n",
    "        orig_img, target = util.generate_data(data,image_id,target_label)\n",
    "        true_label_list = np.argmax(data.test_labels, axis=1)\n",
    "        true_label = true_label_list[image_id]\n",
    "        with open(args.lname,'a+') as f:\n",
    "            f.write(\"\\n Image ID:{}, infer label:{}, true label:{} \\n\".format(image_id, orig_class, true_label))\n",
    "        print(\"Image ID:{}, infer label:{}, true label:{}\".format(image_id, orig_class, true_label))\n",
    "        if true_label != orig_class:\n",
    "            raise(\"True Label is different from the original prediction, pass!\")\n",
    "        count = 0\n",
    "        adv_image = orig_img + delta\n",
    "        gradient = util.compute_gradient(model, adv_image, true_label,args.s2,args.miu,args.q, args.d)\n",
    "        nizo += args.q + 1\n",
    "        delta_tmp = delta + 0.\n",
    "        delta_tmp = delta_tmp - args.eta * gradient\n",
    "        delta_tmp = np.reshape(delta_tmp, (args.d))\n",
    "        top_k_idx = np.argsort(-np.abs(delta_tmp))[0:args.k]\n",
    "        delta = np.zeros_like(delta_tmp)\n",
    "\n",
    "        delta[top_k_idx] = delta_tmp[top_k_idx]\n",
    "        nht += 1\n",
    "        l2_dist = np.linalg.norm(delta, ord=2, keepdims=False)\n",
    "        l0_num = 0\n",
    "        for dim in range(args.d):\n",
    "            if delta[dim] != 0:\n",
    "                l0_num = l0_num + 1\n",
    "        l0_dist = l0_num / args.d\n",
    "\n",
    "        delta = np.reshape(delta, (1, 32,32,3))\n",
    "        adv_image = np.clip(orig_img + delta,-0.5,0.5)\n",
    "        attack_prob, attack_predict_class,_ = util.model_prediction(model, adv_image)\n",
    "        if (nizo + 1) % 1 == 0:\n",
    "            if true_label != attack_predict_class:\n",
    "                with open(args.lname, 'a+') as f:\n",
    "                    f.write(\"Iter %d (Succ): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d \\n\" % (\n",
    "                            nht + 1,image_id, l0_dist, l2_dist, true_label, attack_predict_class))\n",
    "                print(\"Iter %d (Succ): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d\" % (\n",
    "                    nht + 1,image_id, l0_dist, l2_dist, true_label, attack_predict_class))\n",
    "                attack_flag = True\n",
    "                count = count + 1\n",
    "            else:\n",
    "                with open(args.lname, 'a+') as f:\n",
    "                    f.write(\"Iter %d (Fail): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d \\n\" % (\n",
    "                            nht + 1,image_id, l0_dist, l2_dist, true_label, attack_predict_class))\n",
    "                print(\"Iter %d (Fail): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d\" % (\n",
    "                    nht + 1, image_id, l0_dist, l2_dist, true_label, attack_predict_class))\n",
    "            total_loss = compute_batch_loss(delta, imgs_idx, model, data)\n",
    "            print(f\"total loss: {total_loss}\")\n",
    "            hist_loss.append(total_loss.item())\n",
    "            hist_zo.append(nizo)\n",
    "            hist_nht.append(nht)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-09T08:34:10.343592Z",
     "iopub.status.busy": "2023-05-09T08:34:10.342241Z",
     "iopub.status.idle": "2023-05-09T08:34:16.801959Z",
     "shell.execute_reply": "2023-05-09T08:34:16.801151Z"
    }
   },
   "outputs": [],
   "source": [
    "label_list = {0: 'airplane',\n",
    "              1: 'automobile',\n",
    "              2: 'bird',\n",
    "              3: 'cat',\n",
    "              4: 'deer',\n",
    "              5: 'dog',\n",
    "              6: 'frog',\n",
    "              7: 'horse',\n",
    "              8: 'ship',\n",
    "              9: 'truck'}\n",
    "\n",
    "import os\n",
    "\n",
    "fig_path = './szoht/'\n",
    "if not os.path.exists(fig_path):\n",
    "    os.makedirs(fig_path)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    data, model = CIFAR(), CIFARModel('./models/cifar',sess,True)\n",
    "\n",
    "    for i in imgs_idx:\n",
    "        orig_img, target = util.generate_data(data,i,target_label)\n",
    "        adv_image = np.clip(orig_img + delta,-0.5,0.5)\n",
    "        resh = np.reshape(adv_image, (32, 32, 3))\n",
    "\n",
    "\n",
    "\n",
    "        im = Image.fromarray(((resh + 0.5)*255).astype(np.uint8), \"RGB\")\n",
    "        _, orig_class, _ = util.model_prediction(model, np.expand_dims(orig_img[0],axis=0))\n",
    "        print(f'----- sample {i} --------')\n",
    "        print(f\"original class: {label_list[orig_class]}\")\n",
    "        _, new_class, _ = util.model_prediction(model, np.expand_dims(adv_image[0],axis=0))\n",
    "        print(f\"predicted class: {label_list[new_class]}\")\n",
    "        print(f'------------------------')\n",
    "\n",
    "        plt.figure()\n",
    "        plt.imshow(im)\n",
    "        plt.axis('off')\n",
    "        plt.savefig(f'./szoht/attacked_{i}.jpg', bbox_inches='tight')\n",
    "\n",
    "\n",
    "        adv_image = np.clip(orig_img,-0.5,0.5)\n",
    "        resh = np.reshape(adv_image, (32, 32, 3))\n",
    "        im = Image.fromarray(((resh + 0.5)*255).astype(np.uint8), \"RGB\")\n",
    "\n",
    "\n",
    "        plt.figure()\n",
    "        plt.imshow(im)\n",
    "        plt.axis('off')\n",
    "        plt.savefig(f'./szoht/original_{i}.jpg', bbox_inches='tight')\n",
    "        print(f\"l2 distortion: {np.linalg.norm(delta)} in input space\")\n",
    "\n",
    "\n",
    "print(delta)\n",
    "resh = np.reshape(delta, (32, 32, 3))\n",
    "im = Image.fromarray(((resh)*255).astype(np.uint8), \"RGB\")\n",
    "print(np.count_nonzero(delta))\n",
    "\n",
    "plt.figure()\n",
    "plt.imshow(im)\n",
    "plt.axis('off')\n",
    "plt.savefig('./szoht/perturb.jpg', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-09T08:34:16.807129Z",
     "iopub.status.busy": "2023-05-09T08:34:16.806769Z",
     "iopub.status.idle": "2023-05-09T08:34:16.907138Z",
     "shell.execute_reply": "2023-05-09T08:34:16.906342Z"
    }
   },
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(hist_zo, hist_loss, linestyle='-', marker='^', markersize=5, label=f\"lr: {args.eta}\")\n",
    "plt.legend()\n",
    "plt.ylabel('f(w)')\n",
    "plt.xlabel('# IZO')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-09T08:34:16.912812Z",
     "iopub.status.busy": "2023-05-09T08:34:16.912495Z",
     "iopub.status.idle": "2023-05-09T08:34:17.011917Z",
     "shell.execute_reply": "2023-05-09T08:34:17.011145Z"
    }
   },
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(hist_nht, hist_loss, linestyle='-', marker='^', markersize=5, label=f\"lr: {args.eta}\")\n",
    "plt.legend()\n",
    "plt.ylabel('f(w)')\n",
    "plt.xlabel('# NHT')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-09T08:34:17.017869Z",
     "iopub.status.busy": "2023-05-09T08:34:17.017481Z",
     "iopub.status.idle": "2023-05-09T08:34:17.022604Z",
     "shell.execute_reply": "2023-05-09T08:34:17.021754Z"
    }
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "results = {'nizo': hist_zo, \n",
    "           'nht': hist_nht,\n",
    "           'hist': hist_loss}\n",
    "with open('./save/szo_curves.pickle', 'wb') as file:\n",
    "    pickle.dump(results, file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "20cfd2a0d09ae7bc83c97a517227e395fcdaac136dcb90e0b3ddfd33806f9ad5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
