{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "843d4b63-dcc7-49f6-9e10-8c35c2cbe561",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "# Suppress TF log records below level 40 (logging.ERROR).\n",
    "tf.get_logger().setLevel(40)\n",
    "# Switch TF to v1 (graph-mode) behaviour. NOTE(review): presumably needed because alibi's\n",
    "# TF-based CounterfactualProto builds a TF1 graph - confirm for the installed alibi version.\n",
    "tf.compat.v1.disable_v2_behavior()\n",
    "tf.keras.backend.clear_session()\n",
    "# NOTE(review): Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input, UpSampling2D, Model and the bare\n",
    "# `matplotlib` import are not referenced in this notebook; all models are loaded from .h5 files.\n",
    "from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input, UpSampling2D\n",
    "from tensorflow.keras.models import Model, load_model\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "import matplotlib\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from time import time\n",
    "from alibi.explainers import CounterfactualProto\n",
    "\n",
    "# Report the TF version and confirm that eager execution is off after disable_v2_behavior().\n",
    "print('TF version: ', tf.__version__)\n",
    "print('Eager execution enabled: ', tf.executing_eagerly())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c97e036f-0a5d-4910-852b-8fd2476f77ce",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Load the MNIST train/test splits that ship with Keras.\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape)\n",
    "# Make grey the default pyplot colormap, then preview one test digit.\n",
    "plt.gray()\n",
    "plt.imshow(x_test[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa202f77-a5c4-47bf-9ee3-648a98c316c7",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Scale pixel values by 1/255 (as float32) and append a trailing channel axis to both splits.\n",
    "x_train, x_test = (np.expand_dims(split.astype('float32') / 255, axis=-1)\n",
    "                   for split in (x_train, x_test))\n",
    "print('x_train shape:', x_train.shape, 'x_test shape:', x_test.shape)\n",
    "# One-hot encode the labels of both splits.\n",
    "y_train, y_test = (to_categorical(labels) for labels in (y_train, y_test))\n",
    "print('y_train shape:', y_train.shape, 'y_test shape:', y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "12d571d5-aaf2-45f1-8050-32e61e74f0de",
   "metadata": {},
   "outputs": [],
   "source": [
    "xmin, xmax = -.5, .5\n",
    "# Affine map into [xmin, xmax] whose min/max come from the training split only and are reused\n",
    "# for the test split, so both splits go through the identical transform. Test values can land\n",
    "# slightly outside [xmin, xmax] if the test split's range exceeds the training range.\n",
    "train_min, train_max = x_train.min(), x_train.max()\n",
    "x_train = ((x_train - train_min) / (train_max - train_min)) * (xmax - xmin) + xmin\n",
    "x_test = ((x_test - train_min) / (train_max - train_min)) * (xmax - xmin) + xmin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e43266d-5ff6-43ad-ab3e-47daf25eef4e",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Load the pretrained classifier to be explained (path is relative to the working directory).\n",
    "cnn = load_model('mnist_cnn.h5')\n",
    "# NOTE(review): score[1] is reported as accuracy - assumes the model was compiled with accuracy\n",
    "# as its first (and only) metric; confirm against how mnist_cnn.h5 was built.\n",
    "score = cnn.evaluate(x_test, y_test, verbose=0)\n",
    "print('Test accuracy: ', score[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d018102f-e511-48d2-8e00-a4e2f5e77598",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Autoencoder and encoder models; both are passed to CounterfactualProto, and the encoder is\n",
    "# also used directly for the neighbour search further down.\n",
    "ae = load_model('mnist_ae.h5')\n",
    "# compile=False: the encoder is only used for prediction here, so no training config is loaded.\n",
    "enc = load_model('mnist_enc.h5', compile=False)\n",
    "\n",
    "# NOTE(review): decoded_imgs is not used by any later cell; drop it if reconstructions aren't inspected.\n",
    "decoded_imgs = ae.predict(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a6e2761b-23b9-470c-a71c-7088b6ae2c7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explainer input shape: a batch of one sample with the per-sample shape of x_train.\n",
    "shape = (1, *x_train.shape[1:])\n",
    "# Loss weights for the autoencoder (gamma) and prototype (theta) terms, per alibi's docs.\n",
    "gamma = theta = 100.0\n",
    "# Initial prediction-loss weight and how many times alibi adjusts it.\n",
    "c_init, c_steps = 1.0, 2\n",
    "max_iterations = 1000\n",
    "# Perturbed instances are bounded by the value range of the scaled training data.\n",
    "feature_range = (x_train.min(), x_train.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddcef732-0495-4283-8585-fb1ee618080f",
   "metadata": {
    "tags": [],
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "save_dir = \"file/counterfactual_adversarial_examples\"\n",
    "if not os.path.exists(save_dir):\n",
    "    os.makedirs(save_dir)\n",
    "\n",
    "# Prototype-guided counterfactual explainer for the CNN, using the autoencoder and encoder\n",
    "# loaded above; fit() is given the (scaled) training images.\n",
    "cf = CounterfactualProto(\n",
    "    cnn,\n",
    "    shape,\n",
    "    ae_model=ae,\n",
    "    enc_model=enc,\n",
    "    gamma=gamma,\n",
    "    theta=theta,\n",
    "    c_init=c_init,\n",
    "    c_steps=c_steps,\n",
    "    max_iterations=max_iterations,\n",
    "    feature_range=feature_range,\n",
    ")\n",
    "cf.fit(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a180472b-42ec-4253-9afe-ceefcc715cf7",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Encode the training set once; the codes are reused for grouping and for every neighbour lookup.\n",
    "enc_data = enc.predict(x_train)\n",
    "enc_flat = enc_data.reshape(len(enc_data), -1)\n",
    "train_labels = np.argmax(y_train, axis=1)\n",
    "\n",
    "# Row indices into x_train for every class, the matching flattened codes, and the per-class\n",
    "# list of (unflattened) encodings.\n",
    "label_indices = {label: np.flatnonzero(train_labels == label) for label in range(y_train.shape[1])}\n",
    "label_codes = {label: enc_flat[idx] for label, idx in label_indices.items()}\n",
    "enc_label_dict = {label: list(enc_data[idx]) for label, idx in label_indices.items() if idx.size}\n",
    "\n",
    "# Prototypes from the fitted explainer, flattened and stacked one row per class.\n",
    "# proto_labels[k] is the class of row k, so distances map back to labels without relying on dict order.\n",
    "class_proto = cf.class_proto\n",
    "proto_labels = sorted(class_proto)\n",
    "proto_matrix = np.stack([np.asarray(class_proto[label]).reshape(-1) for label in proto_labels])\n",
    "\n",
    "# Output buffer with the same shape and dtype as x_train; every row is filled below.\n",
    "x_neighbor = np.empty_like(x_train)\n",
    "\n",
    "\n",
    "def nearest_neighbor(current):\n",
    "    \"\"\"Return the x_train index of the nearest training sample from a contrasting class.\n",
    "\n",
    "    The target class is the one whose prototype is closest to sample `current` in encoder\n",
    "    space, or the second-closest one when the closest prototype is the sample's own label.\n",
    "    Within the target class, the sample with the smallest Euclidean encoding distance wins.\n",
    "\n",
    "    Args:\n",
    "        current: row index into x_train / y_train.\n",
    "\n",
    "    Returns:\n",
    "        Row index into x_train (not into the per-class subset) of the selected neighbour.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the target class has no training samples.\n",
    "    \"\"\"\n",
    "    enc_x = enc_flat[current]\n",
    "    dist_proto = np.linalg.norm(proto_matrix - enc_x, axis=1)\n",
    "    # Stable sort so order[0] matches np.argmin when two prototypes are equidistant.\n",
    "    order = np.argsort(dist_proto, kind='stable')\n",
    "    nearest_label = proto_labels[order[0]]\n",
    "    if nearest_label == train_labels[current]:\n",
    "        target_label = proto_labels[order[1]]\n",
    "    else:\n",
    "        target_label = nearest_label\n",
    "\n",
    "    candidates = label_indices.get(target_label)\n",
    "    if candidates is None or candidates.size == 0:\n",
    "        raise ValueError(f'No training samples with label {target_label}')\n",
    "    dist = np.linalg.norm(label_codes[target_label] - enc_x, axis=1)\n",
    "    # Map the position within the target class back to a row of x_train.\n",
    "    return candidates[np.argmin(dist)]\n",
    "\n",
    "\n",
    "save_dir = \"file/neighbor_dataset\"\n",
    "\n",
    "if not os.path.exists(save_dir):\n",
    "    os.makedirs(save_dir)\n",
    "\n",
    "for i in range(len(x_train)):\n",
    "    x_neighbor[i] = x_train[nearest_neighbor(i)]\n",
    "\n",
    "    if (i + 1) % 1000 == 0:\n",
    "        print(f'Generated {i + 1} neighbor images.')\n",
    "\n",
    "# Write the result into save_dir (created above) rather than the working directory.\n",
    "np.save(os.path.join(save_dir, \"neighbor_dataset.npy\"), x_neighbor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b05b915-b796-4345-bfd1-6faf9d4ffc77",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Explain each neighbour image with the ORIGINAL sample's label (y_train[i]) as the counterfactual\n",
    "# target. x_neighbor is iterated directly instead of rebinding x_train, so earlier cells stay re-runnable.\n",
    "save_dir = \"file/counterfactual_adversarial_examples\"\n",
    "\n",
    "if not os.path.exists(save_dir):\n",
    "    os.makedirs(save_dir)\n",
    "\n",
    "n_saved = 0\n",
    "for i, neighbor in enumerate(x_neighbor):\n",
    "    current_sample = neighbor[np.newaxis]\n",
    "    target_label = np.argmax(y_train[i])\n",
    "\n",
    "    start_time = time()\n",
    "    explanation = cf.explain(current_sample, target_class=[target_label], verbose=False)\n",
    "    print('Explanation took {:.3f} sec'.format(time() - start_time))\n",
    "\n",
    "    if explanation.cf is not None:\n",
    "        reconstructed_img = np.squeeze(explanation.cf['X'])\n",
    "        save_path = os.path.join(save_dir, f'explanation_{i}_target_label_{target_label}.png')\n",
    "        # Pin the grey scale to feature_range; by default imsave stretches each image to its own min/max.\n",
    "        plt.imsave(save_path, reconstructed_img, cmap='gray',\n",
    "                   vmin=feature_range[0], vmax=feature_range[1])\n",
    "        n_saved += 1\n",
    "\n",
    "    # Progress is reported every 10 samples even when that sample produced no counterfactual.\n",
    "    if (i + 1) % 10 == 0:\n",
    "        print(f'Processed {i + 1} samples, saved {n_saved} explanation images.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
