{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Membership Attack to Keras Model of MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-15 22:28:44.230932: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2023-05-15 22:28:44.254364: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-05-15 22:28:44.589584: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
      "11490434/11490434 [==============================] - 0s 0us/step\n",
      "Train samples: 60000\n",
      "Test samples: 10000\n"
     ]
    }
   ],
   "source": [
    "mnist = tf.keras.datasets.mnist\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "x_train, x_test = x_train/255, x_test/255  # normalize 0~1\n",
    "\n",
    "print('Train samples:', len(x_train))\n",
    "print('Test samples:', len(x_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are going to split the data into 2 disjoints sets:\n",
    "+ One to train the target model\n",
    "+ The other to use as data from the same distribution to train our shadow models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target Train samples: 30000\n",
      "Target Test samples: 5000\n",
      "Attack Train samples: 30000\n",
      "Attack Test samples: 5000\n"
     ]
    }
   ],
   "source": [
    "x_target_train, x_attack_train, y_target_train, y_attack_train = train_test_split(x_train, y_train, test_size=0.5)\n",
    "x_target_test, x_attack_test, y_target_tes, y_attack_test = train_test_split(x_test, y_test, test_size=0.5)\n",
    "\n",
    "print('Target Train samples:', len(x_target_train))\n",
    "print('Target Test samples:', len(x_target_test))\n",
    "print('Attack Train samples:', len(x_attack_train))\n",
    "print('Attack Test samples:', len(x_attack_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the target model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'tensorflow._api.v2.train' has no attribute 'AdamOptimizer'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 8\u001b[0m\n\u001b[1;32m      1\u001b[0m target_model \u001b[38;5;241m=\u001b[39m tf\u001b[38;5;241m.\u001b[39mkeras\u001b[38;5;241m.\u001b[39mmodels\u001b[38;5;241m.\u001b[39mSequential([\n\u001b[1;32m      2\u001b[0m     tf\u001b[38;5;241m.\u001b[39mkeras\u001b[38;5;241m.\u001b[39mlayers\u001b[38;5;241m.\u001b[39mFlatten(),\n\u001b[1;32m      3\u001b[0m     tf\u001b[38;5;241m.\u001b[39mkeras\u001b[38;5;241m.\u001b[39mlayers\u001b[38;5;241m.\u001b[39mDense(\u001b[38;5;241m512\u001b[39m, activation\u001b[38;5;241m=\u001b[39mtf\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mrelu),\n\u001b[1;32m      4\u001b[0m     tf\u001b[38;5;241m.\u001b[39mkeras\u001b[38;5;241m.\u001b[39mlayers\u001b[38;5;241m.\u001b[39mDropout(\u001b[38;5;241m0.2\u001b[39m),\n\u001b[1;32m      5\u001b[0m     tf\u001b[38;5;241m.\u001b[39mkeras\u001b[38;5;241m.\u001b[39mlayers\u001b[38;5;241m.\u001b[39mDense(\u001b[38;5;241m10\u001b[39m, activation\u001b[38;5;241m=\u001b[39mtf\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39msoftmax)\n\u001b[1;32m      6\u001b[0m ])\n\u001b[0;32m----> 8\u001b[0m train_op \u001b[38;5;241m=\u001b[39m tf\u001b[38;5;241m.\u001b[39mtrain\u001b[38;5;241m.\u001b[39mAdamOptimizer(\u001b[38;5;241m1e-4\u001b[39m)\n\u001b[1;32m     10\u001b[0m target_model\u001b[38;5;241m.\u001b[39mcompile(optimizer\u001b[38;5;241m=\u001b[39mtrain_op,\n\u001b[1;32m     11\u001b[0m               loss\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msparse_categorical_crossentropy\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m     12\u001b[0m               metric\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy\u001b[39m\u001b[38;5;124m'\u001b[39m])\n",
      "\u001b[0;31mAttributeError\u001b[0m: module 'tensorflow._api.v2.train' has no attribute 'AdamOptimizer'"
     ]
    }
   ],
   "source": [
    "target_model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
    "    tf.keras.layers.Dropout(0.2),\n",
    "    tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
    "])\n",
    "\n",
    "train_op = tf.train.AdamOptimizer(1e-4)\n",
    "\n",
    "target_model.compile(optimizer=train_op,\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metric=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "30000/30000 [==============================] - 2s 65us/step - loss: 0.6169\n",
      "Epoch 2/5\n",
      "30000/30000 [==============================] - 2s 63us/step - loss: 0.2840\n",
      "Epoch 3/5\n",
      "30000/30000 [==============================] - 2s 62us/step - loss: 0.2271\n",
      "Epoch 4/5\n",
      "30000/30000 [==============================] - 2s 62us/step - loss: 0.1921\n",
      "Epoch 5/5\n",
      "30000/30000 [==============================] - 2s 62us/step - loss: 0.1640\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f6d0e885358>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_model.fit(x_target_train, y_target_train, epochs=5, use_multiprocessing=True, workers=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = target_model.predict_proba(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model accuracy: 0.9538\n"
     ]
    }
   ],
   "source": [
    "print('Model accuracy:', sum(np.argmax(res, axis=1) == y_test) / len(y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the Shadow model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mblearn import ShadowModels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "shadow_model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
    "    tf.keras.layers.Dropout(0.2),\n",
    "    tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
    "])\n",
    "\n",
    "train_op = tf.train.AdamOptimizer(1e-4)\n",
    "shadow_model.compile(optimizer=train_op,\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metric=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "could not broadcast input array from shape (499,784) into shape (499)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-8067387840d1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m                    \u001b[0mlearner\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshadow_model\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m                    \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m                    verbose=0)\n\u001b[0m",
      "\u001b[0;32m~/git-repos/membership_inference/mblearn/shadow_model.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, X, y, n_models, target_classes, learner, **fit_kwargs)\u001b[0m\n\u001b[1;32m     61\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_classes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtarget_classes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_models\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_classes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlearner\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlearner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_model_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlearner\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_models\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/git-repos/membership_inference/mblearn/shadow_model.py\u001b[0m in \u001b[0;36m_split_data\u001b[0;34m(X, y, n_splits, n_classes)\u001b[0m\n\u001b[1;32m    107\u001b[0m         \u001b[0msplits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    108\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mgrouped\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m             \u001b[0msplits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    110\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    111\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0msplits\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/mblearn/lib/python3.6/site-packages/numpy/core/shape_base.py\u001b[0m in \u001b[0;36mvstack\u001b[0;34m(tup)\u001b[0m\n\u001b[1;32m    232\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    233\u001b[0m     \"\"\"\n\u001b[0;32m--> 234\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_nx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0matleast_2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_m\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtup\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    235\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    236\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtup\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/mblearn/lib/python3.6/site-packages/numpy/core/shape_base.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    232\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    233\u001b[0m     \"\"\"\n\u001b[0;32m--> 234\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_nx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0matleast_2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_m\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtup\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    235\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    236\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtup\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/mblearn/lib/python3.6/site-packages/numpy/core/shape_base.py\u001b[0m in \u001b[0;36matleast_2d\u001b[0;34m(*arys)\u001b[0m\n\u001b[1;32m    100\u001b[0m     \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mary\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marys\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m         \u001b[0mary\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mary\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    103\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mary\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mary\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/mblearn/lib/python3.6/site-packages/numpy/core/numeric.py\u001b[0m in \u001b[0;36masanyarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m    551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    552\u001b[0m     \"\"\"\n\u001b[0;32m--> 553\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    555\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: could not broadcast input array from shape (499,784) into shape (499)"
     ]
    }
   ],
   "source": [
    "shm = ShadowModels(x_attack_train,\n",
    "                   y_attack_train,\n",
    "                   n_models=6,\n",
    "                   target_classes=10,\n",
    "                   learner=shadow_model,\n",
    "                   epochs=5,\n",
    "                   verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fb5657f1828>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADr5JREFUeJzt3X+Q1PV9x/HXm+MARVRQuCCSoISqSCLWk5BoO2msRg1TSEZpmIwlGZuLE2lrS384ZCZx2k7HaavWqYlTjESSsSZp1Yop00CYtDRtSjgJ8kNE0J4NeHDJYCtIgOPu3T/uS+aE288uu9/d7x7v52Pm5na/7++P9+zc6767+9n9fszdBSCeEUU3AKAYhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFAjG3mwUTbax2hsIw8JhHJEb+uYH7VK1q0p/GZ2s6SHJbVI+qq7359af4zG6gN2Qy2HBJCwwddVvG7VT/vNrEXSlyXdImmmpEVmNrPa/QForFpe88+RtNvdX3P3Y5K+KWl+Pm0BqLdawj9F0k8G3d+TLXsHM+sws04z6+zV0RoOByBPdX+3392Xu3u7u7e3anS9DwegQrWEf6+kqYPuX5wtAzAM1BL+jZJmmNklZjZK0iclrcqnLQD1VvVQn7sfN7Mlkr6rgaG+Fe6+PbfOANRVTeP87r5a0uqcegHQQHy8FwiK8ANBEX4gKMIPBEX4gaAIPxBUQ7/Pj+FnxLhxyfquL12ZrLdMPVyydtm7epLbThx9KFnfu2Rasu6d25L16DjzA0ERfiAowg8ERfiBoAg/EBThB4JiqC84uyY9VNf9ofOS9ZWfeCRZn1vHizd9/it9yXrXnPod+0zAmR8IivADQRF+ICjCDwRF+IGgCD8QFOEHgmKc/0wwoqVk6ZVHrklu+rWPfjVZv25Mb7Le6+mx9hnfu6tkrfV/0h8C6G9NltW2sT9ZH6sN6R0Ex5kfCIrwA0ERfiAowg8ERfiBoAg/EBThB4KqaZzfzLokHZTUJ+m4u7fn0RROkhjHl6RdD11bsrZ7/leS21614Y5k/azvnJusX/DiW8n6ZTteLlnrP1z6st6ovzw+5PNr7v6zHPYDoIF42g8EVWv4XdIaM3vBzDryaAhAY9T6tP96d99rZpMkrTWzl919/eAVsn8KHZI0RmfXeDgAeanpzO/ue7PfPZKelXTKJRPdfbm7t7t7e6vqeDVHAKel6vCb2VgzG3fitqSbJDEzIjBM1PK0v03Ss2Z2Yj9/7+7/kktXAOqu6vC7+2uSrsqxF5Sw+4HS4/iStOu20mP5V6z/THLb6Z/Zmaz3HzmSrHuyWr6O4jDUBwRF+IGgCD8QFOEHgiL8QFCEHwiKS3c3gf/71Nxk/aWFf5usr/35OSVr7/3z9FBdX5mhPJy5OPMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8zeBQ1PT/4MP96enyf6b228rWfPt26vqCWc+zvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBTj/A3QMnFisv7bd6xO1t/2/mTdf8xYPk4fZ34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCKrsOL+ZrZA0T1KPu8/Klk2Q9C1J0yR1SVro7m/Wr83hbc9vzUjWf+f87ybr23v5OAbyV8mZ/wlJN5+07F5J69x9hqR12X0Aw0jZ8Lv7ekkHTlo8X9LK7PZKSQty7gtAnVX7mr/N3buz2/skteXUD4AGqfkNP3d3SV6qbmYdZtZpZp29Olrr4QDkpNrw7zezyZKU/e4ptaK7L3f3dndvb9XoKg8HIG/Vhn+VpMXZ7cWSnsunHQCNUjb8ZvaUpB9KuszM9pjZnZLul3Sjme2S9OvZfQDDSNkBZHdfVKJ0Q869DF8jWpLl82/qTtbLmb/6d5P1X9KPatp/UezqK9P1vr70Dl7pSpb735/4fMWPtqb3HQCf8AOCIvxAUIQfCIrwA0ERfiAowg8ExXdFc9D/ofcl69+ftaKm/U//9vGatq+nEe+/PFnf+Ydnl6w98yuPJrc9VubctOZg+nFfdF7p/c9b8cfJbd/9ZxuSdfWXGYYcBjjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPPn4PbH1tS0/eVP3p2sX7q+fl/ZbTn/vGT91aUzk/V/XfxXyfqkltLj/FJrcttyrrngpTJrlD72to5Hklt+ZNNdyfqY54fn16gH48wPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzp+Dla/PTdbvfN8/Juvv+q/+9AFq+O74yKkXJ+uzVu1J1p+f9OUyR0iN49fmP46mz02H+9MzQN141s+rPvb+a9OXY3/P81Xvumlw5geCIvxAUIQfCIrwA0ERfiAowg8ERfiBoMqO85vZCknzJPW4+6xs2X2SPivpp9lqy9x9db2abHbjx1Q/nixJ3bcdTdanP5PefuSUi0pv++z+5LZ/MWlTeuc1+vze60rWXvuj9DX/W3sOJev9Z49K1t9+qvRg/IKx/5vcNoJKzvxPSLp5iOUPufvs7Cds8IHhqmz43X29pAMN6AVAA9Xymn+JmW0xsxVmNj63jgA0RLXhf1TSdEmzJXVLeqDUimbWYWadZtbZq/RrWwCNU1X43X2/u/e5e7+kxyTNSay73N3b3b29VekvYgBonKrCb2aTB939uKRt+bQDoFEqGep7StKHJV1oZnskfUnSh81stiSX1CXpc3XsEUAdlA2/uy8aYvHjdehl2Nr5Rlt6hRnp8uUXpcfiez79wWT9lt9fX7L2xQu3pg9exuLXP5Ks7/vCpcn6qM3/XbI24s0fJ7ctdxUDG5n+8/3aG6U/Y7Bgxj8ntx35tpU5+vDHJ/yAoAg/EBThB4Ii/EBQhB8IivADQZm7N+xg59oE/4Dd0LDjNcqIMWOS9e+8+p8N6uT0fXDzbybrF/5B+u+jb+fuPNt5h5b3XpKs9/7d8WR97RWlv9K75nB6evAHr7g6WffeY8l6UTb4Or3lByoap+TMDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMUV3cH962apkfekn7iyzh0lVH/tIW3pq8kfnpb85fsNZ6cvC9SU+onDPE59Nbju1t3k/m5EXzvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBTf589BuUtId33x2mR92cJ/SNY/Na7ntHuK4Kinv88/a91dJWtXLNuX3Pb43jeq6qlofJ8fQFmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU2XF+M5sq6euS2iS5pOXu/rCZTZD0LUnTJHVJWujub6b2daaO89dqz9NXJutb5n6jQZ2c6pXeI8n6qoNXJeu/Me7FkrWP/duS5LZLr12brD/8T/OS9UuW/TBZPxPlPc5/XNJSd58paa6ku81spqR7Ja1z9xmS1mX3AQwTZcPv7t3uvim7fVDSDklTJM2XtDJbbaWkBfVqEkD+Tus1v5lNk3S1pA2S2ty9Oyvt08DLAgDDRMXhN7NzJD0t6R53f2twzQfeOBjyzQMz6zCzTjPr7FX6mmsAGqei8JtZqwaC/6S7P5Mt3m9mk7P6ZElDfvvE3Ze7e7u7t7dqdB49A8hB2fCbmUl6XNIOd39wUGmVpMXZ7cWSnsu/PQD1UslQ3/WS/l3SVkknrrW8TAOv+78t6d2SXtfAUN+B1L4Y6htay/jxyXr/pRdVve9Xbz83Wf/YjRuT9RalL6+946Pp3v2iiSVr/VteTh+7LX1Z8L79fNX5ZKcz1Ff2uv3u/gNJpXZGkoFhik/4AUERfiAowg8ERfiBoAg/EBThB4Li0t3AGYRLdwMoi/ADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4IqG34zm2pm3zezl8xsu5n9Xrb8PjPba2abs59b698ugLyMrGCd45KWuvsmMxsn6QUzW5vVHnL3v65fewDqpWz43b1bUnd2+6CZ7ZA0pd6NAaiv03rNb2bTJF0taUO2aImZbTGzFWY2vsQ2HWbWaWadvTpaU7MA8lNx+M3sHElPS7rH3d+S9Kik6ZJma+CZwQNDbefuy9293d3bWzU6h5YB5KGi8JtZqwaC/6S7PyNJ7r7f3fvcvV/SY5Lm1K9NAHmr5N1+k/S4pB3u/uCg5ZMHrfZxSdvybw9AvVTybv91ku6QtNXMNmfLlklaZGazJbmkLkmfq0uHAOqiknf7fyBpqPm+V+ffDoBG4RN+QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoMzdG3cws59Ken3Qogsl/axhDZyeZu2tWfuS6K1aefb2HnefWMmKDQ3/KQc363T39sIaSGjW3pq1L4neqlVUbzztB4Ii/EBQRYd/ecHHT2nW3pq1L4neqlVIb4W+5gdQnKLP/AAKUkj4zexmM9tpZrvN7N4ieijFzLrMbGs283Bnwb2sMLMeM9s2aNkEM1trZruy30NOk1ZQb00xc3NiZulCH7tmm/G64U/7zaxF0iuSbpS0R9JGSYvc/aWGNlKCmXVJanf3wseEzexXJR2S9HV3n5Ut+0tJB9z9/uwf53h3/5Mm6e0+SYeKnrk5m1Bm8uCZpSUtkPRpFfjYJfpaqAIetyLO/HMk7Xb319z9mKRvSppfQB9Nz93XSzpw0uL5klZmt1dq4I+n4Ur01hTcvdvdN2W3D0o6MbN0oY9doq9CFBH+KZJ+Muj+HjXXlN8uaY2ZvWBmHUU3M4S2bNp0Sdonqa3IZoZQdubmRjppZummeeyqmfE6b7zhd6rr3f2XJd0i6e7s6W1T8oHXbM00XFPRzM2NMsTM0r9Q5GNX7YzXeSsi/HslTR10/+JsWVNw973Z7x5Jz6r5Zh/ef2KS1Ox3T8H9/EIzzdw81MzSaoLHrplmvC4i/BslzTCzS8xslKRPSlpVQB+nMLOx2RsxMrOxkm5S880+vErS4uz2YknPFdjLOzTLzM2lZpZWwY9d08147e4N/5F0qwbe8X9V0heK6KFEX5dKejH72V50b5Ke0sDTwF4NvDdyp6QLJK2TtEvS9yRNaKLeviFpq6QtGgja5IJ6u14DT+m3SNqc/dxa9GOX6KuQx41P+AFB8YYfEBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGg/h+m/o8Y6ixIiwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(x_attack_train[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fb565beeef0>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADr5JREFUeJzt3X+Q1PV9x/HXm+MARVRQuCCSoISqSCLWk5BoO2msRg1TSEZpmIwlGZuLE2lrS384ZCZx2k7HaavWqYlTjESSsSZp1Yop00CYtDRtSjgJ8kNE0J4NeHDJYCtIgOPu3T/uS+aE288uu9/d7x7v52Pm5na/7++P9+zc6767+9n9fszdBSCeEUU3AKAYhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFAjG3mwUTbax2hsIw8JhHJEb+uYH7VK1q0p/GZ2s6SHJbVI+qq7359af4zG6gN2Qy2HBJCwwddVvG7VT/vNrEXSlyXdImmmpEVmNrPa/QForFpe88+RtNvdX3P3Y5K+KWl+Pm0BqLdawj9F0k8G3d+TLXsHM+sws04z6+zV0RoOByBPdX+3392Xu3u7u7e3anS9DwegQrWEf6+kqYPuX5wtAzAM1BL+jZJmmNklZjZK0iclrcqnLQD1VvVQn7sfN7Mlkr6rgaG+Fe6+PbfOANRVTeP87r5a0uqcegHQQHy8FwiK8ANBEX4gKMIPBEX4gaAIPxBUQ7/Pj+FnxLhxyfquL12ZrLdMPVyydtm7epLbThx9KFnfu2Rasu6d25L16DjzA0ERfiAowg8ERfiBoAg/EBThB4JiqC84uyY9VNf9ofOS9ZWfeCRZn1vHizd9/it9yXrXnPod+0zAmR8IivADQRF+ICjCDwRF+IGgCD8QFOEHgmKc/0wwoqVk6ZVHrklu+rWPfjVZv25Mb7Le6+mx9hnfu6tkrfV/0h8C6G9NltW2sT9ZH6sN6R0Ex5kfCIrwA0ERfiAowg8ERfiBoAg/EBThB4KqaZzfzLokHZTUJ+m4u7fn0RROkhjHl6RdD11bsrZ7/leS21614Y5k/azvnJusX/DiW8n6ZTteLlnrP1z6st6ovzw+5PNr7v6zHPYDoIF42g8EVWv4XdIaM3vBzDryaAhAY9T6tP96d99rZpMkrTWzl919/eAVsn8KHZI0RmfXeDgAeanpzO/ue7PfPZKelXTKJRPdfbm7t7t7e6vqeDVHAKel6vCb2VgzG3fitqSbJDEzIjBM1PK0v03Ss2Z2Yj9/7+7/kktXAOqu6vC7+2uSrsqxF5Sw+4HS4/iStOu20mP5V6z/THLb6Z/Zmaz3HzmSrHuyWr6O4jDUBwRF+IGgCD8QFOEHgiL8QFCEHwiKS3c3gf/71Nxk/aWFf5usr/35OSVr7/3z9FBdX5mhPJy5OPMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8zeBQ1PT/4MP96enyf6b228rWfPt26vqCWc+zvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBTj/A3QMnFisv7bd6xO1t/2/mTdf8xYPk4fZ34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCKrsOL+ZrZA0T1KPu8/Klk2Q9C1J0yR1SVro7m/Wr83hbc9vzUjWf+f87ybr23v5OAbyV8mZ/wlJN5+07F5J69x9hqR12X0Aw0jZ8Lv7ekkHTlo8X9LK7PZKSQty7gtAnVX7mr/N3buz2/skteXUD4AGqfkNP3d3SV6qbmYdZtZpZp29Olrr4QDkpNrw7zezyZKU/e4ptaK7L3f3dndvb9XoKg8HIG/Vhn+VpMXZ7cWSnsunHQCNUjb8ZvaUpB9KuszM9pjZnZLul3Sjme2S9OvZfQDDSNkBZHdfVKJ0Q869DF8jWpLl82/qTtbLmb/6d5P1X9KPatp/UezqK9P1vr70Dl7pSpb735/4fMWPtqb3HQCf8AOCIvxAUIQfCIrwA0ERfiAowg8ExXdFc9D/ofcl69+ftaKm/U//9vGatq+nEe+/PFnf+Ydnl6w98yuPJrc9VubctOZg+nFfdF7p/c9b8cfJbd/9ZxuSdfWXGYYcBjjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPPn4PbH1tS0/eVP3p2sX7q+fl/ZbTn/vGT91aUzk/V/XfxXyfqkltLj/FJrcttyrrngpTJrlD72to5Hklt+ZNNdyfqY54fn16gH48wPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzp+Dla/PTdbvfN8/Juvv+q/+9AFq+O74yKkXJ+uzVu1J1p+f9OUyR0iN49fmP46mz02H+9MzQN141s+rPvb+a9OXY3/P81Xvumlw5geCIvxAUIQfCIrwA0ERfiAowg8ERfiBoMqO85vZCknzJPW4+6xs2X2SPivpp9lqy9x9db2abHbjx1Q/nixJ3bcdTdanP5PefuSUi0pv++z+5LZ/MWlTeuc1+vze60rWXvuj9DX/W3sOJev9Z49K1t9+qvRg/IKx/5vcNoJKzvxPSLp5iOUPufvs7Cds8IHhqmz43X29pAMN6AVAA9Xymn+JmW0xsxVmNj63jgA0RLXhf1TSdEmzJXVLeqDUimbWYWadZtbZq/RrWwCNU1X43X2/u/e5e7+kxyTNSay73N3b3b29VekvYgBonKrCb2aTB939uKRt+bQDoFEqGep7StKHJV1oZnskfUnSh81stiSX1CXpc3XsEUAdlA2/uy8aYvHjdehl2Nr5Rlt6hRnp8uUXpcfiez79wWT9lt9fX7L2xQu3pg9exuLXP5Ks7/vCpcn6qM3/XbI24s0fJ7ctdxUDG5n+8/3aG6U/Y7Bgxj8ntx35tpU5+vDHJ/yAoAg/EBThB4Ii/EBQhB8IivADQZm7N+xg59oE/4Dd0LDjNcqIMWOS9e+8+p8N6uT0fXDzbybrF/5B+u+jb+fuPNt5h5b3XpKs9/7d8WR97RWlv9K75nB6evAHr7g6WffeY8l6UTb4Or3lByoap+TMDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMUV3cH962apkfekn7iyzh0lVH/tIW3pq8kfnpb85fsNZ6cvC9SU+onDPE59Nbju1t3k/m5EXzvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBTf589BuUtId33x2mR92cJ/SNY/Na7ntHuK4Kinv88/a91dJWtXLNuX3Pb43jeq6qlofJ8fQFmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU2XF+M5sq6euS2iS5pOXu/rCZTZD0LUnTJHVJWujub6b2daaO89dqz9NXJutb5n6jQZ2c6pXeI8n6qoNXJeu/Me7FkrWP/duS5LZLr12brD/8T/OS9UuW/TBZPxPlPc5/XNJSd58paa6ku81spqR7Ja1z9xmS1mX3AQwTZcPv7t3uvim7fVDSDklTJM2XtDJbbaWkBfVqEkD+Tus1v5lNk3S1pA2S2ty9Oyvt08DLAgDDRMXhN7NzJD0t6R53f2twzQfeOBjyzQMz6zCzTjPr7FX6mmsAGqei8JtZqwaC/6S7P5Mt3m9mk7P6ZElDfvvE3Ze7e7u7t7dqdB49A8hB2fCbmUl6XNIOd39wUGmVpMXZ7cWSnsu/PQD1UslQ3/WS/l3SVkknrrW8TAOv+78t6d2SXtfAUN+B1L4Y6htay/jxyXr/pRdVve9Xbz83Wf/YjRuT9RalL6+946Pp3v2iiSVr/VteTh+7LX1Z8L79fNX5ZKcz1Ff2uv3u/gNJpXZGkoFhik/4AUERfiAowg8ERfiBoAg/EBThB4Li0t3AGYRLdwMoi/ADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4IqG34zm2pm3zezl8xsu5n9Xrb8PjPba2abs59b698ugLyMrGCd45KWuvsmMxsn6QUzW5vVHnL3v65fewDqpWz43b1bUnd2+6CZ7ZA0pd6NAaiv03rNb2bTJF0taUO2aImZbTGzFWY2vsQ2HWbWaWadvTpaU7MA8lNx+M3sHElPS7rH3d+S9Kik6ZJma+CZwQNDbefuy9293d3bWzU6h5YB5KGi8JtZqwaC/6S7PyNJ7r7f3fvcvV/SY5Lm1K9NAHmr5N1+k/S4pB3u/uCg5ZMHrfZxSdvybw9AvVTybv91ku6QtNXMNmfLlklaZGazJbmkLkmfq0uHAOqiknf7fyBpqPm+V+ffDoBG4RN+QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoMzdG3cws59Ken3Qogsl/axhDZyeZu2tWfuS6K1aefb2HnefWMmKDQ3/KQc363T39sIaSGjW3pq1L4neqlVUbzztB4Ii/EBQRYd/ecHHT2nW3pq1L4neqlVIb4W+5gdQnKLP/AAKUkj4zexmM9tpZrvN7N4ieijFzLrMbGs283Bnwb2sMLMeM9s2aNkEM1trZruy30NOk1ZQb00xc3NiZulCH7tmm/G64U/7zaxF0iuSbpS0R9JGSYvc/aWGNlKCmXVJanf3wseEzexXJR2S9HV3n5Ut+0tJB9z9/uwf53h3/5Mm6e0+SYeKnrk5m1Bm8uCZpSUtkPRpFfjYJfpaqAIetyLO/HMk7Xb319z9mKRvSppfQB9Nz93XSzpw0uL5klZmt1dq4I+n4Ur01hTcvdvdN2W3D0o6MbN0oY9doq9CFBH+KZJ+Muj+HjXXlN8uaY2ZvWBmHUU3M4S2bNp0Sdonqa3IZoZQdubmRjppZummeeyqmfE6b7zhd6rr3f2XJd0i6e7s6W1T8oHXbM00XFPRzM2NMsTM0r9Q5GNX7YzXeSsi/HslTR10/+JsWVNw973Z7x5Jz6r5Zh/ef2KS1Ox3T8H9/EIzzdw81MzSaoLHrplmvC4i/BslzTCzS8xslKRPSlpVQB+nMLOx2RsxMrOxkm5S880+vErS4uz2YknPFdjLOzTLzM2lZpZWwY9d08147e4N/5F0qwbe8X9V0heK6KFEX5dKejH72V50b5Ke0sDTwF4NvDdyp6QLJK2TtEvS9yRNaKLeviFpq6QtGgja5IJ6u14DT+m3SNqc/dxa9GOX6KuQx41P+AFB8YYfEBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGg/h+m/o8Y6ixIiwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(x_attack_train.reshape(x_attack_train.shape[0],-1)[0].reshape(28,28))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_attack_train[0,:].ndim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Attack models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mblearn import AttackModels\n",
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_attack = RandomForestClassifier(n_estimators=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "attacker = AttackModels(target_classes=10, attack_learner=rf_attack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Found array with 0 sample(s) (shape=(0, 10)) while a minimum of 1 is required.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-24-01164530c97e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mattacker\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearner_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/git-repos/membership_inference/mblearn/attack_model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, shadow_data, learner_kwargs)\u001b[0m\n\u001b[1;32m     95\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_learner_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearner_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     96\u001b[0m             \u001b[0;31m# train model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 97\u001b[0;31m             \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     99\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fited\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/mblearn/lib/python3.6/site-packages/sklearn/ensemble/forest.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m    248\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    249\u001b[0m         \u001b[0;31m# Validate or convert input data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 250\u001b[0;31m         \u001b[0mX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccept_sparse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"csc\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mDTYPE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    251\u001b[0m         \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccept_sparse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'csc'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mensure_2d\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    252\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0msample_weight\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/mblearn/lib/python3.6/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_array\u001b[0;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)\u001b[0m\n\u001b[1;32m    580\u001b[0m                              \u001b[0;34m\" minimum of %d is required%s.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    581\u001b[0m                              % (n_samples, shape_repr, ensure_min_samples,\n\u001b[0;32m--> 582\u001b[0;31m                                 context))\n\u001b[0m\u001b[1;32m    583\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    584\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mensure_min_features\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Found array with 0 sample(s) (shape=(0, 10)) while a minimum of 1 is required."
     ]
    }
   ],
   "source": [
    "attacker.fit(shm.results, learner_kwargs={})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
