{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "690d7578",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:40:20.055738Z",
     "start_time": "2023-04-02T15:40:16.527912Z"
    }
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy import *\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e00f39f3",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebe56390",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:40:21.199206Z",
     "start_time": "2023-04-02T15:40:20.060198Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load data\n",
    "cifar10 = tf.keras.datasets.cifar10\n",
    "(original_train_images,original_train_labels),(original_test_images,original_test_labels) = cifar10.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a0b2515",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:40:22.030741Z",
     "start_time": "2023-04-02T15:40:21.210754Z"
    }
   },
   "outputs": [],
   "source": [
    "# Choice class = 0 or 1\n",
    "all_index_train = np.where((original_train_labels==0 ) | (original_train_labels==1))[0]\n",
    "all_index_test = np.where((original_test_labels==0 ) | (original_test_labels==1))[0]\n",
    "\n",
    "# Normalization [0,1]\n",
    "original_train_images, original_test_images = original_train_images / 255.0, original_test_images / 255.0\n",
    "\n",
    "# Random choice data for training from dataset\n",
    "# the number of data for training = 2000 \n",
    "# the number of data for testing = 1000\n",
    "np.random.seed(1234)\n",
    "train_index = all_index_train[np.random.randint(0,len(all_index_train),2000)]\n",
    "test_index = all_index_test[np.random.randint(0,len(all_index_test),1000)]\n",
    "\n",
    "train_images,train_labels = original_train_images[train_index,:], original_train_labels[train_index,]\n",
    "test_images,test_labels = original_test_images[test_index,:], original_test_labels[test_index,]\n",
    "\n",
    "\n",
    "train_images = train_images.astype('float32')\n",
    "test_images = test_images.astype('float32')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a83e84bb",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b604340",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:40:22.038671Z",
     "start_time": "2023-04-02T15:40:22.035300Z"
    }
   },
   "outputs": [],
   "source": [
    "# parameters\n",
    "batch_size = 32\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d62df1f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:40:22.191188Z",
     "start_time": "2023-04-02T15:40:22.041841Z"
    }
   },
   "outputs": [],
   "source": [
    "#  cuda memory is force to increase by net training\n",
    "\n",
    "from tensorflow.compat.v1 import ConfigProto\n",
    "from tensorflow.compat.v1 import InteractiveSession\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '/gpu:0,1'\n",
    "\n",
    "config=tf.compat.v1.ConfigProto()\n",
    "config.gpu_options.allow_growth=True\n",
    "sess=tf.compat.v1.Session(config=config) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5e11e4f",
   "metadata": {},
   "source": [
    "# network structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ada34f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:40:22.415356Z",
     "start_time": "2023-04-02T15:40:22.201132Z"
    }
   },
   "outputs": [],
   "source": [
    "# generator of the choiced data\n",
    "\n",
    "train_datagen = tf.keras.preprocessing.image.ImageDataGenerator()\n",
    "train_generator = train_datagen.flow(train_images.reshape(-1, 32, 32, 3),\n",
    "                                     train_labels,\n",
    "                                     batch_size=batch_size)\n",
    "\n",
    "test_datagen = tf.keras.preprocessing.image.ImageDataGenerator()\n",
    "validation_generator = test_datagen.flow(test_images.reshape(-1, 32, 32, 3),\n",
    "                                         test_labels,\n",
    "                                         batch_size=batch_size)\n",
    "\n",
    "# the structure of network\n",
    "# here is a network with single hidden layer\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.InputLayer(input_shape=(32, 32, 3)),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(1, activation='sigmoid')\n",
    "])\n",
    "\n",
    "model.compile(loss='binary_crossentropy',\n",
    "              optimizer=tf.keras.optimizers.Adam(lr=1e-4),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# get every layer's output of the network\n",
    "layer_outputs = [\n",
    "    layer.output for layer in model.layers if layer.name != 'flatten'\n",
    "]\n",
    "activation_model = tf.keras.models.Model(inputs=model.input,\n",
    "                                         outputs=layer_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b101006",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "da1212c2",
   "metadata": {},
   "source": [
    "# initial record matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a14e153",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:40:22.810626Z",
     "start_time": "2023-04-02T15:40:22.419045Z"
    }
   },
   "outputs": [],
   "source": [
    "# parameters (train)\n",
    "num_epochs = 100\n",
    "number_train = 2000\n",
    "batch_ep = 1\n",
    "# batch_size = 32\n",
    "# num_batches = int(num_train_data // batch_size * num_epochs)\n",
    "x_plot = np.arange(num_epochs)*batch_ep\n",
    "reserved_layers = 0\n",
    "\n",
    "for Separability_image,Separability_label in train_generator:\n",
    "    break\n",
    "Separability_label = Separability_label.reshape(-1,1)\n",
    "\n",
    "# initialize record matrix\n",
    "x=activation_model.predict(Separability_image)\n",
    "LS_1_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LS_2_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "J_w_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LDA_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LS_1_squence_base = np.zeros((1,num_epochs))\n",
    "LS_2_squence_base = np.zeros((1,num_epochs))\n",
    "J_w_squence_base = np.zeros((1,num_epochs))\n",
    "LDA_squence_base = np.zeros((1,num_epochs))\n",
    "train_loss_squence = np.zeros((num_epochs,))\n",
    "train_accuracy_squence = np.zeros((num_epochs,))\n",
    "test_loss_squence = np.zeros((num_epochs,))\n",
    "test_accuracy_squence = np.zeros((num_epochs,))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c43169ac",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c447265f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:53:12.371478Z",
     "start_time": "2023-04-02T15:40:22.821141Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# LS of original data\n",
    "LS_1_squence_base[0,:],LS_2_squence_base[0,:],\\\n",
    "J_w_squence_base[0,:],\\\n",
    "LDA_squence_base[0,:]=W(tf.constant(train_images),train_labels.reshape(-1,1))\n",
    "\n",
    "\n",
    "for iter_e in range(num_epochs):\n",
    "    \n",
    "    model.fit(train_generator, epochs=1, \n",
    "                    steps_per_epoch=train_images.shape[0] // 32,\n",
    "                    validation_data=(test_images.reshape(-1, 32, 32, 3), test_labels))\n",
    "    \n",
    "    # \n",
    "    x=activation_model(train_images)\n",
    "    \n",
    "    # Loss and Acc by epochs\n",
    "    train_loss_squence[iter_e],train_accuracy_squence[iter_e]=model.evaluate(train_generator)\n",
    "    test_loss_squence[iter_e],test_accuracy_squence[iter_e]=model.evaluate(validation_generator)\n",
    "    \n",
    "    # LS of every layer's output\n",
    "    for layers_i in range(len(x)-reserved_layers):\n",
    "        LS_1_squence[layers_i,iter_e],LS_2_squence[layers_i,iter_e],J_w_squence[layers_i,iter_e],\\\n",
    "        LDA_squence[layers_i,iter_e]=W(x[layers_i+reserved_layers],train_labels.reshape(-1,1))\n",
    "    \n",
    "    print('**********'+'the',iter_e,'epochs has finished'+'**********')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d147e4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:53:12.387461Z",
     "start_time": "2023-04-02T15:53:12.380110Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import time\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935c32db",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb0d2b4c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:57:59.078271Z",
     "start_time": "2023-04-02T15:57:56.555685Z"
    }
   },
   "outputs": [],
   "source": [
    "layer_name_list = get_layer_name(model)[reserved_layers:]\n",
    "Separability_figure = plot_Separability_figure(layer_name_list,x_plot,LS_1_squence,LS_2_squence,J_w_squence,LDA_squence,LS_1_squence_base,LS_2_squence_base,J_w_squence_base,LDA_squence_base)\n",
    "net_figure = plot_net_figure(layer_name_list,x_plot,train_loss_squence,train_accuracy_squence,test_loss_squence,test_accuracy_squence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1366544",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-02T15:58:08.592277Z",
     "start_time": "2023-04-02T15:58:08.577722Z"
    }
   },
   "outputs": [],
   "source": [
    "info={'layer_name_list':layer_name_list,'x_plot':x_plot,'LS_1_squence':LS_1_squence,'LS_2_squence':LS_2_squence,'J_w_squence':J_w_squence,\n",
    "            'LDA_squence':LDA_squence,'train_loss_squence':train_loss_squence,'train_accuracy_squence':train_accuracy_squence,'test_loss_squence':test_loss_squence,'test_accuracy_squence':test_accuracy_squence,'LS_1_squence_base':LS_1_squence_base,'LS_2_squence_base':LS_2_squence_base,'J_w_squence_base':J_w_squence_base,\n",
    "            'LDA_squence_base':LDA_squence_base,}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "628ae843",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
