{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4db97a4a-9313-4c36-bddc-3cffe732e156",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "import tensorflow as tf\n",
    "from time import strftime\n",
    "from os import path\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# ---- configuration ----\n",
    "my_image_size = (28, 28)  # (height, width) every image is resized to\n",
    "my_input_shape = my_image_size + (1,)  # one channel => grayscale input\n",
    "my_color_mode = 'rgb' if my_input_shape[2] == 3 else 'grayscale'\n",
    "my_train_epochs = 100\n",
    "my_batch = 32\n",
    "my_shuffle_buffer_size = 1000  # NOTE(review): not referenced anywhere in this notebook\n",
    "my_validation_split = 0.05  # fraction of file/mnist/train held out for validation\n",
    "my_seed = 42\n",
    "\n",
    "# Learning rate passed to Adam; a plain constant despite the `_fn` suffix.\n",
    "learning_rate_fn = 1e-4\n",
    "\n",
    "# Seed TensorFlow's global RNG (weight init, dropout) and the generators' shuffling\n",
    "# so repeated runs of this notebook are comparable.\n",
    "tf.random.set_seed(my_seed)\n",
    "\n",
    "# Pixel values are rescaled to [0, 1]; `get_inputs` applies the same /255 scaling.\n",
    "train_datagen = ImageDataGenerator(rescale=1/255)\n",
    "validation_datagen = ImageDataGenerator(rescale=1/255, validation_split=my_validation_split)\n",
    "\n",
    "# Train on the counterfactual adversarial examples (one sub-directory per class).\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "    'file/counterfactual_adversarial_examples',\n",
    "    target_size=my_image_size,\n",
    "    batch_size=my_batch,\n",
    "    color_mode=my_color_mode,\n",
    "    class_mode='categorical',\n",
    "    seed=my_seed)\n",
    "\n",
    "# Validate on a held-out slice of the clean MNIST training set.\n",
    "validation_generator = validation_datagen.flow_from_directory(\n",
    "    'file/mnist/train',\n",
    "    target_size=my_image_size,\n",
    "    batch_size=my_batch,\n",
    "    color_mode=my_color_mode,\n",
    "    class_mode='categorical',\n",
    "    subset='validation',\n",
    "    seed=my_seed)\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # NOTE(review): inside a notebook __name__ is always '__main__', so this always runs.\n",
    "\n",
    "    # ---- model: two Conv2D -> MaxPool -> Dropout stages, then a dense softmax head ----\n",
    "    input_data = tf.keras.layers.Input(shape=my_input_shape)\n",
    "\n",
    "    middle = tf.keras.layers.Conv2D(32, kernel_size=[2,2], strides=(1,1), padding='same', activation=tf.nn.relu)(input_data)\n",
    "    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)\n",
    "    middle = tf.keras.layers.Dropout(0.1)(middle)\n",
    "\n",
    "    middle = tf.keras.layers.Conv2D(64, kernel_size=[2,2], strides=(1,1), padding='same', activation=tf.nn.relu)(middle)\n",
    "    middle = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same')(middle)\n",
    "    middle = tf.keras.layers.Dropout(0.1)(middle)\n",
    "\n",
    "    dense = tf.keras.layers.Flatten()(middle)\n",
    "    dense = tf.keras.layers.Dense(256, activation='relu')(dense)\n",
    "    dense = tf.keras.layers.Dropout(0.1)(dense)\n",
    "\n",
    "    # One softmax unit per digit class (0-9).\n",
    "    output_data = tf.keras.layers.Dense(10, activation='softmax')(dense)\n",
    "\n",
    "    model = tf.keras.Model(inputs=input_data, outputs=output_data)\n",
    "    model.compile(optimizer=tf.optimizers.Adam(learning_rate_fn),\n",
    "                  loss=tf.losses.categorical_crossentropy,\n",
    "                  metrics=['accuracy'])\n",
    "    model.summary()\n",
    "\n",
    "    # Save the full model every `checkpoint_every_epochs` epochs (model_010.hdf5,\n",
    "    # model_020.hdf5, ...). The `period` argument is deprecated in tf.keras; its\n",
    "    # replacement `save_freq` counts batches, hence the multiplication by the\n",
    "    # number of batches per epoch.\n",
    "    checkpoint_every_epochs = 10\n",
    "    checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
    "        'model_{epoch:03d}.hdf5',\n",
    "        monitor='val_loss', verbose=1,\n",
    "        save_best_only=False, save_weights_only=False,\n",
    "        save_freq=checkpoint_every_epochs * len(train_generator))\n",
    "\n",
    "    start_time = strftime('%Y-%m-%d %H:%M:%S')\n",
    "    history = model.fit(\n",
    "        train_generator,\n",
    "        epochs=my_train_epochs,\n",
    "        verbose=1,\n",
    "        callbacks=[checkpoint],\n",
    "        validation_data=validation_generator)\n",
    "    # Stamp the end right after fit so the interval covers training, not plotting.\n",
    "    end_time = strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(f'Training started {start_time}, finished {end_time}')\n",
    "\n",
    "    # Train vs. validation curves. The validation series comes from\n",
    "    # `validation_generator` (a subset of file/mnist/train), not from a test set.\n",
    "    for metric, title in (('accuracy', 'Model accuracy'), ('loss', 'Model loss')):\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.plot(history.history[metric], label='Train')\n",
    "        ax.plot(history.history['val_' + metric], label='Validation')\n",
    "        ax.set(title=title, xlabel='Epoch', ylabel=metric.capitalize())\n",
    "        ax.legend(loc='upper left')\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60f82d72-0680-4c65-b7f4-4923af74e4da",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing import image\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "def get_inputs(src=(), target_size=(28, 28), color_mode='grayscale'):\n",
    "    \"\"\"Load image files into one float array scaled to [0, 1].\n",
    "\n",
    "    Args:\n",
    "        src: iterable of image file paths. Defaults to an empty tuple.\n",
    "        target_size: (height, width) every image is resized to.\n",
    "        color_mode: 'grayscale', 'rgb' or 'rgba', forwarded to `image.load_img`.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (len(src), height, width, channels). An empty `src`\n",
    "        still yields a 4-D (0, height, width, channels) array instead of a\n",
    "        1-D empty one, so it has the rank `model.predict` expects.\n",
    "    \"\"\"\n",
    "    # The loop variable is not named `path`, so it does not shadow the `path`\n",
    "    # imported from os in the training cell.\n",
    "    arrays = [\n",
    "        image.img_to_array(image.load_img(img_path, color_mode=color_mode, target_size=target_size))\n",
    "        for img_path in src\n",
    "    ]\n",
    "    if not arrays:\n",
    "        channels = {'grayscale': 1, 'rgb': 3, 'rgba': 4}[color_mode]\n",
    "        return np.zeros((0,) + tuple(target_size) + (channels,), dtype=np.float32)\n",
    "    # Same /255 scaling as the ImageDataGenerator(rescale=1/255) used in training.\n",
    "    return np.array(arrays) / 255.0\n",
    "\n",
    "def images_path_list(predict_dir):\n",
    "    \"\"\"Collect PNG files from the per-digit sub-directories of `predict_dir`.\n",
    "\n",
    "    Expects the layout `predict_dir/<digit>/<name>.png`. Directory and file names\n",
    "    are sorted, so the result is deterministic and repeated calls agree\n",
    "    (`os.listdir` order is arbitrary). Files directly under `predict_dir` are ignored.\n",
    "\n",
    "    Returns:\n",
    "        (images, labels): parallel lists of file paths and integer labels.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: a sub-directory holding PNG files is not named '0'..'9'.\n",
    "    \"\"\"\n",
    "    label_map = {str(digit): digit for digit in range(10)}\n",
    "    images = []\n",
    "    labels = []\n",
    "    for class_name in sorted(os.listdir(predict_dir)):\n",
    "        class_dir = os.path.join(predict_dir, class_name)\n",
    "        if not os.path.isdir(class_dir):\n",
    "            continue\n",
    "        for fn in sorted(os.listdir(class_dir)):\n",
    "            # Require a real '.png' extension, case-insensitively ('a.PNG' counts;\n",
    "            # 'apng' with no dot does not).\n",
    "            if fn.lower().endswith('.png'):\n",
    "                images.append(os.path.join(class_dir, fn))\n",
    "                labels.append(label_map[class_name])\n",
    "    return images, labels\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Relies on `model`, `my_image_size`, `my_batch` and `my_input_shape` from the\n",
    "    # training cell; run that cell first. 'model_100.hdf5' is the checkpoint the\n",
    "    # training cell's ModelCheckpoint writes at epoch 100.\n",
    "    checkpoint_save_path = 'model_100.hdf5'\n",
    "    print('------load the model------')\n",
    "    model.load_weights(checkpoint_save_path)\n",
    "\n",
    "    # Model outputs for every MNIST training image. Each directory is listed once,\n",
    "    # so paths and labels come from the same listing and stay aligned.\n",
    "    train_path = 'file/mnist/train'\n",
    "    train_images, train_label = images_path_list(train_path)\n",
    "    pre_train = get_inputs(train_images)\n",
    "    output_train_adversarial = model.predict(pre_train, batch_size=32)\n",
    "    train_label = np.asarray(train_label)\n",
    "\n",
    "    # Loss/accuracy on the MNIST test split. A separate name keeps the training\n",
    "    # cell's `validation_datagen` from being overwritten.\n",
    "    test_path = 'file/mnist/test'\n",
    "    test_datagen = ImageDataGenerator(rescale=1/255)\n",
    "    val_generator = test_datagen.flow_from_directory(\n",
    "        test_path,\n",
    "        target_size=my_image_size,\n",
    "        batch_size=my_batch,\n",
    "        color_mode='rgb' if my_input_shape[2] == 3 else 'grayscale',\n",
    "        class_mode='categorical')\n",
    "    # `evaluate_generator` is deprecated; `evaluate` accepts the generator directly.\n",
    "    scores = model.evaluate(val_generator, verbose=1)\n",
    "    print('%s: %.2f' % (model.metrics_names[0], scores[0]))\n",
    "    print('%s: %.2f%%' % (model.metrics_names[1], scores[1] * 100))\n",
    "\n",
    "    # Model outputs for every MNIST test image, aligned with `test_label`.\n",
    "    test_images, test_label = images_path_list(test_path)\n",
    "    pre_test = get_inputs(test_images)\n",
    "    output_test_adversarial = model.predict(pre_test, batch_size=32)\n",
    "    test_label = np.asarray(test_label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59cde60f-a944-4cc2-a3de-4f77cec3a94c",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Bundle the model outputs for both MNIST splits and their ground-truth labels\n",
    "# into one archive; each keyword becomes that array's key in the .npz file.\n",
    "np.savez(\n",
    "    'counterfactual_adversarial_example.npz',\n",
    "    output_train_adversarial=output_train_adversarial,\n",
    "    output_test_adversarial=output_test_adversarial,\n",
    "    train_label=train_label,\n",
    "    test_label=test_label,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
