{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Keras VGG16 for CIFAR10 test\n",
    "keras vgg16 CIFAR10 for cifar10 test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-30T17:48:45.852220Z",
     "start_time": "2023-04-30T17:48:43.565318Z"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Activation\n",
    "from data_utils import *\n",
    "\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots\n",
    "plt.rcParams['image.interpolation'] = 'nearest'\n",
    "plt.rcParams['image.cmap'] = 'gray'\n",
    "\n",
    "import tensorflow as tf \n",
    "from tensorflow.keras import backend as k\n",
    "import os\n",
    "from tensorflow.compat.v1 import ConfigProto\n",
    "config = ConfigProto()\n",
    "# config.gpu_options.per_process_gpu_memory_fraction = 0.1\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '/gpu:0,1'\n",
    "\n",
    "config=tf.compat.v1.ConfigProto()\n",
    "config.gpu_options.allow_growth=True\n",
    "sess=tf.compat.v1.Session(config=config) \n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-30T17:48:50.717282Z",
     "start_time": "2023-04-30T17:48:45.856000Z"
    }
   },
   "outputs": [],
   "source": [
    "# get data\n",
    "cifar10_data = CIFAR10Data()\n",
    "\n",
    "x_train_o, y_train_o, x_test_o, y_test_o = cifar10_data.get_data(subtract_mean=True)\n",
    "\n",
    "all_index_train = np.where((np.argmax(y_train_o,axis=1)==0 ) | (np.argmax(y_train_o,axis=1)==1))[0]\n",
    "all_index_test = np.where((np.argmax(y_test_o,axis=1)==0 ) | (np.argmax(y_test_o,axis=1)==1))[0]\n",
    "\n",
    "np.random.seed(1234)\n",
    "LS_index = all_index_train[np.random.randint(0,len(all_index_train),500)]\n",
    "Separability_images,Separability_labels = x_train_o[LS_index,:], y_train_o[LS_index,]\n",
    "\n",
    "x_train,y_train = x_train_o[all_index_train,:],y_train_o[all_index_train,:]\n",
    "x_test,y_test = x_test_o[all_index_test,:],y_test_o[all_index_test,:]\n",
    "\n",
    "# ##########################################\n",
    "# y_train = np.argmax(y_train,axis=1)\n",
    "# y_test = np.argmax(y_test,axis=1)\n",
    "\n",
    "# from tensorflow.keras.utils import to_categorical\n",
    "# y_train = to_categorical(y_train, num_classes=2)\n",
    "# y_test = to_categorical(y_test, num_classes=2)\n",
    "# ##########################################\n",
    "\n",
    "num_train = int(x_train.shape[0] * 0.9)\n",
    "num_val = x_train.shape[0] - num_train\n",
    "mask = list(range(num_train, num_train+num_val))\n",
    "x_val = x_train[mask]\n",
    "y_val = y_train[mask]\n",
    "\n",
    "mask = list(range(num_train))\n",
    "x_train = x_train[mask]\n",
    "y_train = y_train[mask]\n",
    "\n",
    "print('num train:%d num val:%d' % (num_train, num_val))\n",
    "data = (x_train, y_train, x_val, y_val, x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-30T17:48:51.492226Z",
     "start_time": "2023-04-30T17:48:50.721226Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from classifiers.vgg import VGGNet \n",
    "from tensorflow.keras import losses\n",
    "from tensorflow.keras import optimizers\n",
    "\n",
    "\n",
    "weight_decay = 5e-4\n",
    "lr = 1e-3\n",
    "num_classes = 10\n",
    "\n",
    "vgg = VGGNet(classes=num_classes, \n",
    "             input_shape=x_train.shape[1:], \n",
    "             weight_decay=weight_decay, \n",
    "             conv_block_num=4,\n",
    "             fc_layers=2,\n",
    "             fc_units=512\n",
    "             ) \n",
    "\n",
    "# sgd\n",
    "opt = keras.optimizers.SGD(lr=lr, momentum=0.9, nesterov=False)\n",
    "vgg.compile(loss='categorical_crossentropy', \n",
    "            optimizer=opt, \n",
    "            metrics=['accuracy'])\n",
    "vgg.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-30T17:48:51.545496Z",
     "start_time": "2023-04-30T17:48:51.496949Z"
    }
   },
   "outputs": [],
   "source": [
    "layer_outputs = [\n",
    "    layer.output for layer in vgg.layers if 'input' not in layer.name\n",
    "]\n",
    "activation_model = tf.keras.models.Model(inputs=vgg.input,\n",
    "                                         outputs=layer_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-30T17:48:54.571707Z",
     "start_time": "2023-04-30T17:48:51.548497Z"
    }
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy import *\n",
    "\n",
    "# parameters (train)\n",
    "num_epochs = 100\n",
    "\n",
    "batch_ep = 1\n",
    "\n",
    "x_plot = np.arange(num_epochs)*batch_ep\n",
    "reserved_layers = 0\n",
    "\n",
    "\n",
    "\n",
    "# initialize record matrix\n",
    "x=activation_model.predict(Separability_images)\n",
    "LS_1_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LS_2_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "J_w_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LDA_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LS_1_squence_base = np.zeros((1,num_epochs))\n",
    "LS_2_squence_base = np.zeros((1,num_epochs))\n",
    "J_w_squence_base = np.zeros((1,num_epochs))\n",
    "LDA_squence_base = np.zeros((1,num_epochs))\n",
    "train_loss_squence = np.zeros((num_epochs,))\n",
    "train_accuracy_squence = np.zeros((num_epochs,))\n",
    "test_loss_squence = np.zeros((num_epochs,))\n",
    "test_accuracy_squence = np.zeros((num_epochs,))\n",
    "\n",
    "Separability_labels = np.argmax(Separability_labels,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-30T17:49:00.201835Z",
     "start_time": "2023-04-30T17:48:54.576114Z"
    }
   },
   "outputs": [],
   "source": [
    "LS_1_squence_base[0,:],LS_2_squence_base[0,:],\\\n",
    "J_w_squence_base[0,:],\\\n",
    "LDA_squence_base[0,:]=W(tf.constant(Separability_images/ 255.0),Separability_labels.reshape(-1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-04-30T17:48:43.339Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import ReduceLROnPlateau\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler\n",
    "# fit data with data augmentation or not\n",
    "data_augmentation = True\n",
    "\n",
    "\n",
    "\n",
    "base_lr = 1e-3\n",
    "\n",
    "def lr_scheduler(epoch, lr):\n",
    "    global base_lr\n",
    "    \n",
    "    new_lr = base_lr\n",
    "    if epoch <= 2:\n",
    "        pass\n",
    "    elif epoch > 2 and epoch <= 80:\n",
    "        new_lr = base_lr * 0.1\n",
    "    else:\n",
    "        new_lr = base_lr * 0.01\n",
    "    return new_lr\n",
    "\n",
    "def lr_scheduler2(epoch, lr):\n",
    "    #print( \"Learning rate:\", lr)\n",
    "    return lr\n",
    "\n",
    "callbacks = [LearningRateScheduler(lr_scheduler2)]\n",
    "\n",
    "batch_size = 128\n",
    "epochs = 100\n",
    "\n",
    "if data_augmentation:\n",
    "    # datagen\n",
    "    datagen = ImageDataGenerator(\n",
    "        featurewise_center=False,  # set input mean to 0 over the dataset\n",
    "        samplewise_center=False,  # set each sample mean to 0\n",
    "        featurewise_std_normalization=False,  # divide inputs by std of the dataset\n",
    "        samplewise_std_normalization=False,  # divide each input by its std\n",
    "        zca_whitening=False,  # apply ZCA whitening\n",
    "        rotation_range=15,  # randomly rotate images in the range (degrees, 0 to 180)\n",
    "        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)\n",
    "        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)\n",
    "        horizontal_flip=True,  # randomly flip images\n",
    "        vertical_flip=False, # randomly flip images\n",
    "    ) \n",
    "    # (std, mean, and principal components if ZCA whitening is applied).\n",
    "    # datagen.fit(x_train)\n",
    "    print('train with data augmentation')\n",
    "    \n",
    "    for iter_e in range(num_epochs):\n",
    "        \n",
    "        opt.lr = lr_scheduler(iter_e, opt.lr)\n",
    "        history = vgg.fit_generator(generator=datagen.flow(x_train, y_train, batch_size=batch_size),\n",
    "                                    epochs=1,\n",
    "                                    callbacks=callbacks,\n",
    "                                    validation_data=(x_val, y_val)\n",
    "                                    )\n",
    "\n",
    "        x=activation_model(Separability_images)\n",
    "        train_gen = datagen.flow(x_train, y_train, batch_size=batch_size)\n",
    "        # Loss and Acc by epochs\n",
    "        train_loss_squence[iter_e],train_accuracy_squence[iter_e]=vgg.evaluate(train_gen)\n",
    "        test_loss_squence[iter_e],test_accuracy_squence[iter_e]=vgg.evaluate(x_val, y_val)\n",
    "\n",
    "        # LS of every layer's output\n",
    "        for layers_i in range(len(x)-reserved_layers):\n",
    "            LS_1_squence[layers_i,iter_e],LS_2_squence[layers_i,iter_e],J_w_squence[layers_i,iter_e],\\\n",
    "            LDA_squence[layers_i,iter_e]=W(x[layers_i+reserved_layers],Separability_labels.reshape(-1,1))    \n",
    "\n",
    "        print('**********'+'the',iter_e,'epochs has finished'+'**********')   \n",
    "\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "else:\n",
    "    print('train without data augmentation')\n",
    "    history = vgg.fit(x_train, y_train, \n",
    "                      batch_size=batch_size, epochs=epochs, \n",
    "                      callbacks=[reduce_lr],\n",
    "                      validation_data=(x_val, y_val)\n",
    "                      )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-01T16:15:31.149436Z",
     "start_time": "2023-05-01T16:15:31.087905Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_Separability_figure(layer_name_list,x_plot,LS_1_squence,LS_2_squence,J_w_squence,LDA_squence,LS_1_squence_base,LS_2_squence_base,J_w_squence_base,LDA_squence_base):\n",
    "\n",
    "    \n",
    "    #figure = plt.figure(1,figsize=(8,8))\n",
    "    figure = plt.figure(figsize=(16,20))\n",
    "    #layer_name_list = get_layer_name(model)[reserved_layers:]\n",
    "    \n",
    "    \n",
    "    # LS_0\n",
    "    ax1 = plt.subplot(2,2,1)\n",
    "\n",
    "    for i in range(len(LS_1_squence)-1):\n",
    "        if 'flatten' not in layer_name_list[i]:\n",
    "            plt.plot(x_plot,LS_1_squence[i,:],label=layer_name_list[i])\n",
    "    plt.plot(x_plot,LS_1_squence_base[0,:],'-k',label='base_line')\n",
    "    plt.title('$\\mathrm{LS}_1$')\n",
    "    #plt.legend(bbox_to_anchor=(1,0),loc=3)\n",
    "\n",
    "\n",
    "    # LS_1\n",
    "    ax2 = plt.subplot(2,2,2)\n",
    "\n",
    "    for i in range(len(LS_2_squence)-1):\n",
    "        if 'flatten' not in layer_name_list[i]:\n",
    "            plt.plot(x_plot,LS_2_squence[i,:],label=layer_name_list[i])\n",
    "    plt.plot(x_plot,LS_2_squence_base[0,:],'-k',label='base_line')\n",
    "    plt.title('$\\mathrm{LS}_2$')\n",
    "    plt.legend(bbox_to_anchor=(1,0),loc=3)\n",
    "\n",
    "\n",
    "    # LS_2\n",
    "    ax3 = plt.subplot(2,2,3)\n",
    "\n",
    "    for i in range(len(J_w_squence)-1):\n",
    "        if 'flatten' not in layer_name_list[i]:\n",
    "            plt.plot(x_plot,J_w_squence[i,:],label=layer_name_list[i])\n",
    "    plt.plot(x_plot,J_w_squence_base[0,:],'-k',label='base_line')\n",
    "    plt.title('$\\mathrm{J}_w$')\n",
    "    #plt.legend(bbox_to_anchor=(1,0),loc=3)\n",
    "\n",
    "    # J_w\n",
    "    ax4 = plt.subplot(2,2,4)\n",
    "\n",
    "    for i in range(len(LDA_squence)-1):\n",
    "        if 'flatten' not in layer_name_list[i]:\n",
    "            plt.plot(x_plot,LDA_squence[i,:],label=layer_name_list[i])\n",
    "    plt.plot(x_plot,LDA_squence_base[0,:],'-k',label='base_line')\n",
    "    plt.title('$\\mathrm{LDA}$')\n",
    "    plt.legend(bbox_to_anchor=(1,0),loc=3)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    \n",
    "    \n",
    "    \n",
    "    return figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-01T16:15:31.801791Z",
     "start_time": "2023-05-01T16:15:31.753447Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import time\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-01T16:15:36.202433Z",
     "start_time": "2023-05-01T16:15:32.454952Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#layer_name_list = get_layer_name(vgg)[reserved_layers:]\n",
    "layer_name_list = [\n",
    "    layer.name for layer in vgg.layers if 'input' not in layer.name\n",
    "]\n",
    "\n",
    "\n",
    "Separability_figure = plot_Separability_figure(layer_name_list,x_plot,LS_1_squence,LS_2_squence,J_w_squence,LDA_squence,LS_1_squence_base,LS_2_squence_base,J_w_squence_base,LDA_squence_base)\n",
    "net_figure = plot_net_figure(layer_name_list,x_plot,train_loss_squence,train_accuracy_squence,test_loss_squence,test_accuracy_squence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-01T16:15:41.952868Z",
     "start_time": "2023-05-01T16:15:41.898002Z"
    }
   },
   "outputs": [],
   "source": [
    "info={'layer_name_list':layer_name_list,'x_plot':x_plot,'LS_1_squence':LS_1_squence,'LS_2_squence':LS_2_squence,'J_w_squence':J_w_squence,\n",
    "            'LDA_squence':LDA_squence,'train_loss_squence':train_loss_squence,'train_accuracy_squence':train_accuracy_squence,'test_loss_squence':test_loss_squence,'test_accuracy_squence':test_accuracy_squence,'LS_1_squence_base':LS_1_squence_base,'LS_2_squence_base':LS_2_squence_base,'J_w_squence_base':J_w_squence_base,\n",
    "            'LDA_squence_base':LDA_squence_base,}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
