{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "from Linear_Separability_numpy import *\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the full CIFAR-10 train/test splits through Keras; the working subsets are drawn from these.\n",
    "cifar10 = tf.keras.datasets.cifar10\n",
    "(\n",
    "    (original_train_images, original_train_labels),\n",
    "    (original_test_images, original_test_labels),\n",
    ") = cifar10.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary task: keep only CIFAR-10 classes 0 and 1.\n",
    "# np.where(...)[0] yields the row indices of the matching samples.\n",
    "all_index_train = np.where((original_train_labels == 0) | (original_train_labels == 1))[0]\n",
    "all_index_test = np.where((original_test_labels == 0) | (original_test_labels == 1))[0]\n",
    "\n",
    "np.random.seed(1234)\n",
    "\n",
    "number_train_samples = 2000\n",
    "number_test_samples = 1000\n",
    "# Sample WITHOUT replacement. np.random.randint draws with replacement, which put\n",
    "# duplicated images into each subset (fewer distinct samples than requested).\n",
    "train_index = all_index_train[np.random.choice(len(all_index_train), number_train_samples, replace=False)]\n",
    "test_index = all_index_test[np.random.choice(len(all_index_test), number_test_samples, replace=False)]\n",
    "\n",
    "# Scale only the selected subset to [0, 1] (assumes 8-bit pixel values). The full\n",
    "# original_* arrays are left unmodified, so re-running this cell does not divide by\n",
    "# 255 a second time, and no full-size float64 copy of the dataset is allocated.\n",
    "train_images, train_labels = original_train_images[train_index] / 255.0, original_train_labels[train_index]\n",
    "test_images, test_labels = original_test_images[test_index] / 255.0, original_test_labels[test_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch size used by the train/validation generators built below.\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.compat.v1 import ConfigProto\n",
    "from tensorflow.compat.v1 import InteractiveSession\n",
    "import os\n",
    "\n",
    "# CUDA_VISIBLE_DEVICES expects comma-separated integer device ids (or UUIDs). The previous\n",
    "# value '/gpu:0,1' is a TensorFlow device string; CUDA stops at the first invalid entry,\n",
    "# so it exposed no GPU at all. This must be set before TensorFlow initializes the GPUs.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'\n",
    "\n",
    "# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.\n",
    "# set_memory_growth is the TF2 API for this and applies to eager/Keras execution.\n",
    "for gpu in tf.config.experimental.list_physical_devices('GPU'):\n",
    "    try:\n",
    "        tf.config.experimental.set_memory_growth(gpu, True)\n",
    "    except RuntimeError as err:\n",
    "        # Raised if the GPUs were already initialized earlier in this kernel session.\n",
    "        print(err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_datagen = tf.keras.preprocessing.image.ImageDataGenerator()\n",
    "train_generator = train_datagen.flow(train_images.reshape(-1, 32, 32, 3), train_labels, batch_size=batch_size)\n",
    "\n",
    "test_datagen = tf.keras.preprocessing.image.ImageDataGenerator()\n",
    "validation_generator = test_datagen.flow(test_images.reshape(-1, 32, 32, 3), test_labels, batch_size=batch_size)\n",
    "\n",
    "# Trailing comments give each layer's output shape as height*width*channels.\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.InputLayer(input_shape=(32, 32, 3)),\n",
    "    tf.keras.layers.Conv2D(8, (3, 3), (1, 1), padding='same', activation='relu'),   # 32*32*8\n",
    "    tf.keras.layers.MaxPooling2D((2, 2), padding='same'),                            # 16*16*8\n",
    "    tf.keras.layers.Conv2D(8, (3, 3), strides=2, padding='same', activation='relu'), # 8*8*8\n",
    "    tf.keras.layers.MaxPooling2D((2, 2), padding='same'),                            # 4*4*8\n",
    "    tf.keras.layers.Flatten(),                                                       # 128\n",
    "    tf.keras.layers.Dense(50, activation='sigmoid'),\n",
    "    tf.keras.layers.Dense(1, activation='sigmoid')])\n",
    "# `learning_rate` replaces the deprecated `lr` keyword of Keras optimizers.\n",
    "model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), metrics=['accuracy'])\n",
    "\n",
    "# Auxiliary model with one output per entry of model.layers, in layer order.\n",
    "layer_outputs = [layer.output for layer in model.layers]\n",
    "activation_model = tf.keras.models.Model(inputs=model.input, outputs=layer_outputs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 200\n",
    "# Taken from the sampled subset so it cannot drift from the sample size chosen earlier.\n",
    "number_train = train_images.shape[0]\n",
    "batch_ep = 1  # scale factor for the x-axis values in x_plot\n",
    "x_plot = np.arange(num_epochs) * batch_ep\n",
    "# Number of leading layer outputs skipped by the separability metrics.\n",
    "reserved_layers = 0\n",
    "# Derived from the data; the previous hard-coded (number_train, 28, 28, 1) did not match\n",
    "# the (32, 32, 3) images fed to the model.\n",
    "Separability_image_shape = (number_train,) + train_images.shape[1:]\n",
    "\n",
    "# Draw one batch from the training generator.\n",
    "Separability_image, Separability_label = next(train_generator)\n",
    "Separability_label = Separability_label.reshape(-1, 1)\n",
    "\n",
    "# One row per tracked output of activation_model, one column per epoch.\n",
    "# len(activation_model.outputs) counts the layer outputs directly, without running a\n",
    "# prediction (len() of a single-output predict() result would be the batch size).\n",
    "num_tracked_layers = len(activation_model.outputs) - reserved_layers\n",
    "LS_1_squence = np.zeros((num_tracked_layers, num_epochs))\n",
    "LS_2_squence = np.zeros((num_tracked_layers, num_epochs))\n",
    "J_w_squence = np.zeros((num_tracked_layers, num_epochs))\n",
    "LDA_squence = np.zeros((num_tracked_layers, num_epochs))\n",
    "# Baseline rows: the same metrics computed on the raw input images.\n",
    "LS_1_squence_base = np.zeros((1, num_epochs))\n",
    "LS_2_squence_base = np.zeros((1, num_epochs))\n",
    "J_w_squence_base = np.zeros((1, num_epochs))\n",
    "LDA_squence_base = np.zeros((1, num_epochs))\n",
    "# Per-epoch loss/accuracy returned by model.evaluate.\n",
    "train_loss_squence = np.zeros((num_epochs,))\n",
    "train_accuracy_squence = np.zeros((num_epochs,))\n",
    "test_loss_squence = np.zeros((num_epochs,))\n",
    "test_accuracy_squence = np.zeros((num_epochs,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for iter_e in range(num_epochs):\n",
    "    # Train for exactly one epoch; steps_per_epoch follows the generators' batch_size\n",
    "    # rather than a separately hard-coded 32.\n",
    "    model.fit(train_generator, epochs=1,\n",
    "              steps_per_epoch=train_images.shape[0] // batch_size,\n",
    "              validation_data=(test_images.reshape(-1, 32, 32, 3), test_labels))\n",
    "\n",
    "    # Outputs of every layer tracked by activation_model on the whole training subset.\n",
    "    x = activation_model(train_images)\n",
    "    # Baseline: separability metrics of the raw input images.\n",
    "    # NOTE(review): the inputs do not change between epochs, so if W is deterministic\n",
    "    # this call could be made once before the loop.\n",
    "    LS_1_squence_base[0, iter_e], LS_2_squence_base[0, iter_e], \\\n",
    "        J_w_squence_base[0, iter_e], \\\n",
    "        LDA_squence_base[0, iter_e] = W(train_images, train_labels.reshape(-1, 1))\n",
    "\n",
    "    train_loss_squence[iter_e], train_accuracy_squence[iter_e] = model.evaluate(train_generator)\n",
    "    test_loss_squence[iter_e], test_accuracy_squence[iter_e] = model.evaluate(validation_generator)\n",
    "\n",
    "    # Separability metrics of each tracked layer's activations, skipping the first reserved_layers outputs.\n",
    "    for layers_i in range(len(x) - reserved_layers):\n",
    "        LS_1_squence[layers_i, iter_e], LS_2_squence[layers_i, iter_e], J_w_squence[layers_i, iter_e], \\\n",
    "            LDA_squence[layers_i, iter_e] = W(x[layers_i + reserved_layers].numpy(), train_labels.reshape(-1, 1))\n",
    "\n",
    "    print('**********'+'the',iter_e,'epochs has finished'+'**********')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import time\n",
    "# NOTE(review): with no category/module filter this ignores every warning for the rest of the\n",
    "# kernel session (only cells run after it are affected); `time` is not used in any visible cell.\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip the first reserved_layers layer names, matching how the metric rows were filled.\n",
    "layer_name_list = get_layer_name(model)[reserved_layers:]\n",
    "\n",
    "Separability_figure = plot_Separability_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    LS_1_squence, LS_2_squence, J_w_squence, LDA_squence,\n",
    "    LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base,\n",
    ")\n",
    "net_figure = plot_net_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    train_loss_squence, train_accuracy_squence,\n",
    "    test_loss_squence, test_accuracy_squence,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the plot inputs and every recorded metric array into one dict keyed by variable name.\n",
    "info = dict(\n",
    "    layer_name_list=layer_name_list,\n",
    "    x_plot=x_plot,\n",
    "    LS_1_squence=LS_1_squence,\n",
    "    LS_2_squence=LS_2_squence,\n",
    "    J_w_squence=J_w_squence,\n",
    "    LDA_squence=LDA_squence,\n",
    "    train_loss_squence=train_loss_squence,\n",
    "    train_accuracy_squence=train_accuracy_squence,\n",
    "    test_loss_squence=test_loss_squence,\n",
    "    test_accuracy_squence=test_accuracy_squence,\n",
    "    LS_1_squence_base=LS_1_squence_base,\n",
    "    LS_2_squence_base=LS_2_squence_base,\n",
    "    J_w_squence_base=J_w_squence_base,\n",
    "    LDA_squence_base=LDA_squence_base,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
