{
 "cells": [
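  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Keras ResNet classifier for CIFAR-10\n",
    "Train a ResNet-20 on a two-class subset of CIFAR-10 (classes 0 and 1) and, after every epoch, measure how linearly separable the output of each layer is on a fixed sample of 500 training images."
   ]
  },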
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-12T14:37:51.811905Z",
     "start_time": "2023-05-12T14:37:49.587478Z"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Activation\n",
    "from data_utils import *\n",
    "\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots\n",
    "plt.rcParams['image.interpolation'] = 'nearest'\n",
    "plt.rcParams['image.cmap'] = 'gray'\n",
    "\n",
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import backend as k\n",
    "from tensorflow.compat.v1 import ConfigProto\n",
    "\n",
    "# CUDA_VISIBLE_DEVICES expects a comma-separated list of device ordinals, not TensorFlow\n",
    "# device strings: '/gpu:0,1' is an invalid entry, and CUDA stops parsing at the first\n",
    "# invalid entry, so no GPU would be visible. It must be set before the session below\n",
    "# initialises the GPUs.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'\n",
    "\n",
    "# Grow GPU memory on demand instead of reserving the whole device up front.\n",
    "config = ConfigProto()\n",
    "# config.gpu_options.per_process_gpu_memory_fraction = 0.1\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.compat.v1.Session(config=config)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
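  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get data\n",
    "Load CIFAR-10, keep only classes 0 and 1, draw a fixed 500-image sample for the separability measurements, and hold out the last 10% of the two-class training images for validation."
   ]
  },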
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-12T14:37:56.484413Z",
     "start_time": "2023-05-12T14:37:51.815910Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load CIFAR-10 and keep only the images of classes 0 and 1.\n",
    "# The labels keep CIFAR10Data's one-hot layout (argmax over axis 1 gives the class id).\n",
    "cifar10_data = CIFAR10Data()\n",
    "\n",
    "x_train_o, y_train_o, x_test_o, y_test_o = cifar10_data.get_data(subtract_mean=True)\n",
    "\n",
    "SELECTED_CLASSES = (0, 1)\n",
    "all_index_train = np.where(np.isin(np.argmax(y_train_o, axis=1), SELECTED_CLASSES))[0]\n",
    "all_index_test = np.where(np.isin(np.argmax(y_test_o, axis=1), SELECTED_CLASSES))[0]\n",
    "\n",
    "# Fixed sample of two-class training images on which layer separability is measured\n",
    "# after every epoch. Drawn without replacement so no image is counted twice.\n",
    "NUM_SEPARABILITY_SAMPLES = 500\n",
    "np.random.seed(1234)\n",
    "LS_index = np.random.choice(all_index_train, size=NUM_SEPARABILITY_SAMPLES, replace=False)\n",
    "Separability_images, Separability_labels = x_train_o[LS_index], y_train_o[LS_index]\n",
    "\n",
    "x_train, y_train = x_train_o[all_index_train], y_train_o[all_index_train]\n",
    "x_test, y_test = x_test_o[all_index_test], y_test_o[all_index_test]\n",
    "\n",
    "# Hold out the last 10% of the two-class training images for validation.\n",
    "num_train = int(x_train.shape[0] * 0.9)\n",
    "num_val = x_train.shape[0] - num_train\n",
    "x_val, y_val = x_train[num_train:], y_train[num_train:]\n",
    "x_train, y_train = x_train[:num_train], y_train[:num_train]\n",
    "assert len(x_train) == num_train and len(x_val) == num_val\n",
    "\n",
    "print('num train:%d num val:%d' % (num_train, num_val))\n",
    "data = (x_train, y_train, x_val, y_val, x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test with ResNet-20\n",
    "ResNet-20 is the 20-layer CIFAR-10 network described in the ResNet paper (He et al., *Deep Residual Learning for Image Recognition*, 2016)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-12T14:37:57.508502Z",
     "start_time": "2023-05-12T14:37:56.488074Z"
    }
   },
   "outputs": [],
   "source": [
    "from classifiers.ResNet import ResNet20ForCIFAR10\n",
    "from tensorflow.keras import losses\n",
    "from tensorflow.keras import optimizers\n",
    "\n",
    "weight_decay = 1e-4\n",
    "lr = 1e-3\n",
    "# The labels keep CIFAR10Data's one-hot layout (only classes 0 and 1 occur), so the\n",
    "# output layer must be as wide as the label vectors.\n",
    "num_classes = 10\n",
    "assert y_train.shape[1] == num_classes, 'label width does not match num_classes'\n",
    "resnet20 = ResNet20ForCIFAR10(input_shape=(32, 32, 3), classes=num_classes, weight_decay=weight_decay)\n",
    "# `lr` is a deprecated alias in tf.keras optimizers; `learning_rate` is the supported name.\n",
    "opt = optimizers.SGD(learning_rate=lr, momentum=0.9, nesterov=False)\n",
    "resnet20.compile(optimizer=opt,\n",
    "                 loss=losses.categorical_crossentropy,\n",
    "                 metrics=['accuracy'])\n",
    "resnet20.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-12T14:37:57.564787Z",
     "start_time": "2023-05-12T14:37:57.512872Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Expose every non-input layer of resnet20 as an output, so one forward pass returns\n",
    "# all intermediate activations in layer order.\n",
    "probed_layers = [lyr for lyr in resnet20.layers if 'input' not in lyr.name]\n",
    "layer_outputs = [lyr.output for lyr in probed_layers]\n",
    "activation_model = tf.keras.models.Model(inputs=resnet20.input, outputs=layer_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-12T14:38:00.069650Z",
     "start_time": "2023-05-12T14:37:57.567825Z"
    },
    "run_control": {
     "marked": true
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy import *\n",
    "\n",
    "# parameters (train)\n",
    "num_epochs = 100  # training length; every per-epoch record array below is sized with it\n",
    "batch_ep = 1\n",
    "\n",
    "x_plot = np.arange(num_epochs) * batch_ep\n",
    "reserved_layers = 0  # number of leading activation_model outputs left out of the records\n",
    "\n",
    "# One row per recorded layer output, one column per epoch.\n",
    "x = activation_model.predict(Separability_images)\n",
    "num_recorded_layers = len(x) - reserved_layers\n",
    "LS_1_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "LS_2_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "J_w_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "LDA_squence = np.zeros((num_recorded_layers, num_epochs))\n",
    "# Single-row counterparts for the measures computed on the input images.\n",
    "LS_1_squence_base = np.zeros((1, num_epochs))\n",
    "LS_2_squence_base = np.zeros((1, num_epochs))\n",
    "J_w_squence_base = np.zeros((1, num_epochs))\n",
    "LDA_squence_base = np.zeros((1, num_epochs))\n",
    "train_loss_squence = np.zeros((num_epochs,))\n",
    "train_accuracy_squence = np.zeros((num_epochs,))\n",
    "test_loss_squence = np.zeros((num_epochs,))\n",
    "test_accuracy_squence = np.zeros((num_epochs,))\n",
    "\n",
    "# Convert the one-hot labels to class ids. The guard keeps the cell re-runnable:\n",
    "# a second argmax over axis 1 of the already-converted 1-D labels would fail.\n",
    "if Separability_labels.ndim > 1:\n",
    "    Separability_labels = np.argmax(Separability_labels, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-12T14:38:40.159790Z",
     "start_time": "2023-05-12T14:38:00.072702Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Separability measures of the input images themselves, stored in the baseline rows.\n",
    "# NOTE(review): the network receives Separability_images unscaled while this baseline\n",
    "# divides by 255 — confirm this is intended.\n",
    "baseline_inputs = tf.constant(Separability_images / 255.0)\n",
    "baseline_labels = Separability_labels.reshape(-1, 1)\n",
    "(LS_1_squence_base[0, :], LS_2_squence_base[0, :],\n",
    " J_w_squence_base[0, :], LDA_squence_base[0, :]) = W(baseline_inputs, baseline_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-12T14:38:40.237397Z",
     "start_time": "2023-05-12T14:38:40.166499Z"
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow.keras.backend as K\n",
    "\n",
    "\n",
    "def plot_history(history):\n",
    "    \"\"\"\n",
    "    Plot the per-epoch loss and accuracy curves of a training run.\n",
    "\n",
    "    :param history: History object returned by CIFAR10Solver.train() or model.fit()\n",
    "    \"\"\"\n",
    "    logs = history.history\n",
    "    # tf.keras logs metrics=['accuracy'] as 'accuracy'/'val_accuracy';\n",
    "    # older standalone Keras used 'acc'/'val_acc'.\n",
    "    acc_key = 'accuracy' if 'accuracy' in logs else 'acc'\n",
    "\n",
    "    plt.plot(logs['loss'])\n",
    "    plt.plot(logs['val_loss'])\n",
    "    plt.xlabel('epoch')\n",
    "    plt.ylabel('Loss value')\n",
    "    plt.legend(['train', 'validation'], loc='upper left')\n",
    "    plt.show()\n",
    "\n",
    "    plt.plot(logs[acc_key])\n",
    "    plt.plot(logs['val_' + acc_key])\n",
    "    plt.xlabel('epoch')\n",
    "    plt.ylabel('acc value')\n",
    "    plt.legend(['train', 'validation'], loc='upper left')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler\n",
    "\n",
    "base_lr = 1e-3\n",
    "\n",
    "def lr_scheduler(epoch, lr):\n",
    "    \"\"\"\n",
    "    Step schedule that CIFAR10Solver.train() applies before every epoch.\n",
    "\n",
    "    Returns base_lr for epochs 0-2, base_lr * 0.1 for epochs 3-80 and base_lr * 0.01\n",
    "    afterwards. The current rate `lr` is ignored.\n",
    "    \"\"\"\n",
    "    if epoch <= 2:\n",
    "        return base_lr\n",
    "    if epoch <= 80:\n",
    "        return base_lr * 0.1\n",
    "    return base_lr * 0.01\n",
    "\n",
    "def lr_scheduler2(epoch, lr):\n",
    "    \"\"\"Identity schedule: keep the optimizer's current learning rate.\"\"\"\n",
    "    #print( \"Learning rate:\", lr)\n",
    "    return lr\n",
    "\n",
    "callbacks = [LearningRateScheduler(lr_scheduler2)]\n",
    "\n",
    "\n",
    "class CIFAR10Solver(object):\n",
    "    \"\"\"\n",
    "    Training driver for a CIFAR10 classifier. The model is built and compiled outside\n",
    "    and passed to __init__().\n",
    "\n",
    "    With data augmentation, train() also records after every epoch the train and\n",
    "    validation loss/accuracy and the separability scores returned by W for every\n",
    "    output of the global activation_model, writing them into the global *_squence\n",
    "    arrays (which must have room for at least `epochs` epochs).\n",
    "\n",
    "    Example usage:\n",
    "\n",
    "    model = MyAwesomeModel(...)\n",
    "    model.compile(optimizer=..., loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "    model.summary()\n",
    "    solver = CIFAR10Solver(model, data)\n",
    "    history = solver.train()\n",
    "    plot_history(history)\n",
    "    solver.test()\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, data):\n",
    "        \"\"\"\n",
    "        :param model: a compiled Keras model\n",
    "        :param data:  tuple (x_train, y_train, x_val, y_val, x_test, y_test)\n",
    "        \"\"\"\n",
    "        self.model = model\n",
    "        self.X_train, self.Y_train, self.X_val, self.Y_val, self.X_test, self.Y_test = data\n",
    "\n",
    "    def __on_epoch_end(self, epoch, logs=None):\n",
    "        print(K.eval(self.model.optimizer.lr))\n",
    "\n",
    "    def train(self, epochs=100, batch_size=128, data_augmentation=True, callbacks=None):\n",
    "        \"\"\"\n",
    "        Train the model.\n",
    "\n",
    "        :param epochs: number of epochs to train\n",
    "        :param batch_size: mini-batch size\n",
    "        :param data_augmentation: if True, train on randomly shifted/flipped images one\n",
    "            epoch at a time and record per-epoch metrics; otherwise call model.fit once\n",
    "        :param callbacks: optional list of Keras callbacks passed to every fit call\n",
    "        :return: a History whose .history holds one value per trained epoch\n",
    "            (None when training with augmentation for 0 epochs)\n",
    "        :raises ValueError: if `epochs` exceeds the length of the global record arrays\n",
    "        \"\"\"\n",
    "        if not data_augmentation:\n",
    "            print('train without data augmentation')\n",
    "            return self.model.fit(self.X_train, self.Y_train,\n",
    "                                  batch_size=batch_size, epochs=epochs,\n",
    "                                  callbacks=callbacks,\n",
    "                                  validation_data=(self.X_val, self.Y_val),\n",
    "                                  )\n",
    "\n",
    "        if epochs > len(train_loss_squence):\n",
    "            raise ValueError('epochs=%d exceeds the %d epochs the record arrays were sized for'\n",
    "                             % (epochs, len(train_loss_squence)))\n",
    "\n",
    "        datagen = ImageDataGenerator(\n",
    "            width_shift_range=4,  # integer => random horizontal shift measured in pixels\n",
    "            height_shift_range=4,  # integer => random vertical shift measured in pixels\n",
    "            horizontal_flip=True,  # randomly mirror images left-right\n",
    "        )\n",
    "        print('train with data augmentation')\n",
    "        train_gen = datagen.flow(self.X_train, self.Y_train, batch_size=batch_size)\n",
    "\n",
    "        history = None\n",
    "        merged_logs = {}\n",
    "        for iter_e in range(epochs):\n",
    "            lr_var = self.model.optimizer.lr\n",
    "            K.set_value(lr_var, lr_scheduler(iter_e, K.get_value(lr_var)))\n",
    "            # Train a single epoch. initial_epoch/epochs give callbacks such as\n",
    "            # LearningRateScheduler the real epoch index (with epochs=1 alone they would\n",
    "            # always see epoch 0); a scheduler callback overrides the rate set above.\n",
    "            history = self.model.fit(train_gen,\n",
    "                                     initial_epoch=iter_e, epochs=iter_e + 1,\n",
    "                                     validation_data=(self.X_val, self.Y_val),\n",
    "                                     callbacks=callbacks)\n",
    "            for key, values in history.history.items():\n",
    "                merged_logs.setdefault(key, []).extend(values)\n",
    "\n",
    "            self._record_epoch(iter_e, train_gen)\n",
    "            print('**********'+'the',iter_e,'epochs has finished'+'**********')\n",
    "\n",
    "        if history is not None:\n",
    "            # Expose the whole run, not just the last one-epoch fit.\n",
    "            history.history = merged_logs\n",
    "            history.epoch = list(range(epochs))\n",
    "        return history\n",
    "\n",
    "    def _record_epoch(self, iter_e, train_gen):\n",
    "        \"\"\"Store loss/accuracy and per-layer separability scores in column `iter_e` of the global record arrays.\"\"\"\n",
    "        # Assumes compile() used exactly one metric, so evaluate() returns [loss, metric].\n",
    "        # NOTE(review): train metrics are measured on the augmented generator.\n",
    "        train_loss_squence[iter_e], train_accuracy_squence[iter_e] = self.model.evaluate(train_gen)\n",
    "        test_loss_squence[iter_e], test_accuracy_squence[iter_e] = self.model.evaluate(self.X_val, self.Y_val)\n",
    "\n",
    "        activations = activation_model(Separability_images)\n",
    "        labels = Separability_labels.reshape(-1, 1)\n",
    "        for layers_i in range(len(activations) - reserved_layers):\n",
    "            (LS_1_squence[layers_i, iter_e], LS_2_squence[layers_i, iter_e],\n",
    "             J_w_squence[layers_i, iter_e], LDA_squence[layers_i, iter_e]) = W(\n",
    "                activations[layers_i + reserved_layers], labels)\n",
    "\n",
    "    def test(self):\n",
    "        \"\"\"Evaluate on the test split and print its loss and accuracy.\"\"\"\n",
    "        loss, acc = self.model.evaluate(self.X_test, self.Y_test)\n",
    "        print('test data loss:%.2f acc:%.4f' % (loss, acc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-12T14:36:40.528Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "import math\n",
    "\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler\n",
    "\n",
    "# Divide the rate by 10 after 50% and by 100 after 75% of training; for a 182-epoch\n",
    "# run this gives the milestones 91 and 137.\n",
    "LR_DECAY_EPOCHS = (math.ceil(num_epochs * 0.5), math.ceil(num_epochs * 0.75))\n",
    "\n",
    "def lr_scheduler1(epoch):\n",
    "    \"\"\"Step decay of the global `lr` at the LR_DECAY_EPOCHS milestones.\"\"\"\n",
    "    new_lr = lr\n",
    "    if epoch > LR_DECAY_EPOCHS[1]:\n",
    "        new_lr = lr * 0.01\n",
    "    elif epoch > LR_DECAY_EPOCHS[0]:\n",
    "        new_lr = lr * 0.1\n",
    "    print('new lr:%.2e' % new_lr)\n",
    "    return new_lr\n",
    "\n",
    "reduce_lr1 = LearningRateScheduler(lr_scheduler1)\n",
    "\n",
    "solver = CIFAR10Solver(resnet20, data)\n",
    "# Train for exactly num_epochs: the per-epoch record arrays were sized with it, so a\n",
    "# larger value would index past their end.\n",
    "history = solver.train(epochs=num_epochs, batch_size=128, data_augmentation=True, callbacks=[reduce_lr1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-12T14:36:40.529Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import warnings\n",
    "\n",
    "# Hide all Python warnings emitted by the cells that follow.\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-12T14:36:40.530Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "layer_name_list = get_layer_name(resnet20)[reserved_layers:]\n",
    "\n",
    "# Per-layer separability curves over epochs, together with the input-image baselines.\n",
    "separability_curves = (LS_1_squence, LS_2_squence, J_w_squence, LDA_squence)\n",
    "separability_baselines = (LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base)\n",
    "Separability_figure = plot_Separability_figure(layer_name_list, x_plot, *separability_curves, *separability_baselines)\n",
    "\n",
    "# Loss and accuracy of the network over epochs.\n",
    "network_curves = (train_loss_squence, train_accuracy_squence, test_loss_squence, test_accuracy_squence)\n",
    "net_figure = plot_net_figure(layer_name_list, x_plot, *network_curves)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-12T14:36:40.534Z"
    }
   },
   "outputs": [],
   "source": [
    "# Collect every recorded series, plus its plotting axes, in a single dict.\n",
    "info = dict(\n",
    "    layer_name_list=layer_name_list,\n",
    "    x_plot=x_plot,\n",
    "    LS_1_squence=LS_1_squence,\n",
    "    LS_2_squence=LS_2_squence,\n",
    "    J_w_squence=J_w_squence,\n",
    "    LDA_squence=LDA_squence,\n",
    "    train_loss_squence=train_loss_squence,\n",
    "    train_accuracy_squence=train_accuracy_squence,\n",
    "    test_loss_squence=test_loss_squence,\n",
    "    test_accuracy_squence=test_accuracy_squence,\n",
    "    LS_1_squence_base=LS_1_squence_base,\n",
    "    LS_2_squence_base=LS_2_squence_base,\n",
    "    J_w_squence_base=J_w_squence_base,\n",
    "    LDA_squence_base=LDA_squence_base,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
