{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "690d7578",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:09.844036Z",
     "start_time": "2023-05-14T07:51:07.417227Z"
    }
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy_mul import *\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e00f39f3",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebe56390",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.104762Z",
     "start_time": "2023-05-14T07:51:09.847922Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load the MNIST train/test splits through the Keras dataset helper.\n",
    "mnist = tf.keras.datasets.mnist\n",
    "mnist_train_split, mnist_test_split = mnist.load_data()\n",
    "original_train_images, original_train_labels = mnist_train_split\n",
    "original_test_images, original_test_labels = mnist_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a0b2515",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.315985Z",
     "start_time": "2023-05-14T07:51:10.108242Z"
    }
   },
   "outputs": [],
   "source": [
    "# Use every digit class. To restrict the experiment to a subset of classes,\n",
    "# build the index arrays with np.where on the labels instead, e.g.\n",
    "# all_index_train = np.where((original_train_labels==0) | (original_train_labels==1))[0]\n",
    "# all_index_test = np.where((original_test_labels==0) | (original_test_labels==1))[0]\n",
    "all_index_train = np.arange(len(original_train_images))\n",
    "all_index_test = np.arange(len(original_test_images))\n",
    "\n",
    "# Randomly pick 2000 training samples and 1000 test samples.\n",
    "# NOTE: np.random.randint draws the indices with replacement, so a subset\n",
    "# may contain the same image more than once.\n",
    "np.random.seed(1234)\n",
    "train_index = all_index_train[np.random.randint(0, len(all_index_train), 2000)]\n",
    "test_index = all_index_test[np.random.randint(0, len(all_index_test), 1000)]\n",
    "\n",
    "# Scale only the selected pixels to [0, 1] and store them as float32.\n",
    "# The original_* arrays are left untouched, so re-running this cell does not\n",
    "# divide by 255 a second time and the full dataset is never converted to float.\n",
    "train_images = (original_train_images[train_index] / 255.0).astype('float32')\n",
    "test_images = (original_test_images[test_index] / 255.0).astype('float32')\n",
    "\n",
    "train_labels = original_train_labels[train_index]\n",
    "test_labels = original_test_labels[test_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3719b816",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.323474Z",
     "start_time": "2023-05-14T07:51:10.319214Z"
    }
   },
   "outputs": [],
   "source": [
    "# Keep independent copies of the training subset (images and class labels)\n",
    "# before train_labels is one-hot encoded; W() is later called with these copies.\n",
    "s_train_images, s_train_labels = train_images.copy(), train_labels.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ec4f549",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1652176",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.330667Z",
     "start_time": "2023-05-14T07:51:10.326162Z"
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "# One-hot encode the labels for the 10-way softmax output.\n",
    "# Encode from the integer label sources (s_train_labels and\n",
    "# original_test_labels[test_index]) instead of from train_labels/test_labels\n",
    "# themselves, so re-running this cell never encodes already encoded labels.\n",
    "train_labels = to_categorical(s_train_labels, num_classes=10)\n",
    "test_labels = to_categorical(original_test_labels[test_index], num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be6868c0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a83e84bb",
   "metadata": {},
   "source": [
    "# Training setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b604340",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.335834Z",
     "start_time": "2023-05-14T07:51:10.332960Z"
    }
   },
   "outputs": [],
   "source": [
    "# Mini-batch size handed to the Keras data generators\n",
    "batch_size = 32\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d62df1f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.444082Z",
     "start_time": "2023-05-14T07:51:10.340110Z"
    }
   },
   "outputs": [],
   "source": [
    "# Let TensorFlow grow its GPU memory allocation on demand instead of\n",
    "# reserving all of it up front.\n",
    "\n",
    "from tensorflow.compat.v1 import ConfigProto\n",
    "from tensorflow.compat.v1 import InteractiveSession\n",
    "import os\n",
    "\n",
    "# CUDA_VISIBLE_DEVICES takes a comma-separated list of device indices (or UUIDs),\n",
    "# not TensorFlow device strings such as '/gpu:0'; an invalid entry hides the GPUs.\n",
    "# It has to be set before TensorFlow first initializes the GPUs.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'\n",
    "\n",
    "config = tf.compat.v1.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.compat.v1.Session(config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5e11e4f",
   "metadata": {},
   "source": [
    "# network structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ada34f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.572757Z",
     "start_time": "2023-05-14T07:51:10.451750Z"
    }
   },
   "outputs": [],
   "source": [
    "# Batch generators over the selected subsets (no augmentation is configured)\n",
    "\n",
    "train_datagen = tf.keras.preprocessing.image.ImageDataGenerator()\n",
    "train_generator = train_datagen.flow(train_images.reshape(-1, 28, 28, 1),\n",
    "                                     train_labels,\n",
    "                                     batch_size=batch_size)\n",
    "\n",
    "test_datagen = tf.keras.preprocessing.image.ImageDataGenerator()\n",
    "validation_generator = test_datagen.flow(test_images.reshape(-1, 28, 28, 1),\n",
    "                                         test_labels,\n",
    "                                         batch_size=batch_size)\n",
    "\n",
    "# Network structure: a fully connected classifier with five hidden ReLU layers\n",
    "# of 50 units each, followed by a 10-way softmax output layer\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(50, activation='relu'),\n",
    "    tf.keras.layers.Dense(10, activation='softmax')\n",
    "])\n",
    "\n",
    "# `lr` is a deprecated alias of `learning_rate`; use the supported argument name\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# Expose the output of every layer except the Flatten layer.\n",
    "# Filter by layer type rather than by the literal name 'flatten': Keras names\n",
    "# the layer 'flatten_1', 'flatten_2', ... when this cell is re-run in the same\n",
    "# kernel, and a name comparison would then stop excluding it.\n",
    "layer_outputs = [\n",
    "    layer.output for layer in model.layers\n",
    "    if not isinstance(layer, tf.keras.layers.Flatten)\n",
    "]\n",
    "activation_model = tf.keras.models.Model(inputs=model.input,\n",
    "                                         outputs=layer_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b101006",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "da1212c2",
   "metadata": {},
   "source": [
    "# Initialize record matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a14e153",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T07:51:10.835400Z",
     "start_time": "2023-05-14T07:51:10.577157Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training parameters\n",
    "num_epochs = 100\n",
    "number_train = 2000\n",
    "batch_ep = 1\n",
    "reserved_layers = 0\n",
    "\n",
    "# x-axis values for the plots: one point per epoch, spaced by batch_ep\n",
    "x_plot = np.arange(num_epochs)*batch_ep\n",
    "\n",
    "# Pull a single batch from the training generator\n",
    "Separability_image, Separability_label = next(train_generator)\n",
    "Separability_label = Separability_label.reshape(-1, 1)\n",
    "\n",
    "# Run the batch through the activation model; len(x) is the number of\n",
    "# recorded layer outputs and fixes the row count of the per-layer matrices\n",
    "x = activation_model.predict(Separability_image)\n",
    "layer_record_shape = (len(x) - reserved_layers, num_epochs)\n",
    "baseline_record_shape = (1, num_epochs)\n",
    "epoch_record_shape = (num_epochs,)\n",
    "\n",
    "# Per-layer records: one row per recorded layer, one column per epoch\n",
    "LS_1_squence = np.zeros(layer_record_shape)\n",
    "LS_2_squence = np.zeros(layer_record_shape)\n",
    "J_w_squence = np.zeros(layer_record_shape)\n",
    "LDA_squence = np.zeros(layer_record_shape)\n",
    "\n",
    "# Base-line records: a single row, one column per epoch\n",
    "LS_1_squence_base = np.zeros(baseline_record_shape)\n",
    "LS_2_squence_base = np.zeros(baseline_record_shape)\n",
    "J_w_squence_base = np.zeros(baseline_record_shape)\n",
    "LDA_squence_base = np.zeros(baseline_record_shape)\n",
    "\n",
    "# Loss / accuracy records: one entry per epoch\n",
    "train_loss_squence = np.zeros(epoch_record_shape)\n",
    "train_accuracy_squence = np.zeros(epoch_record_shape)\n",
    "test_loss_squence = np.zeros(epoch_record_shape)\n",
    "test_accuracy_squence = np.zeros(epoch_record_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c43169ac",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c447265f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T08:06:49.506985Z",
     "start_time": "2023-05-14T07:51:10.840689Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Base line: the four values W returns for the raw training images, computed\n",
    "# once and broadcast across all epochs of the *_base records\n",
    "(LS_1_squence_base[0, :], LS_2_squence_base[0, :],\n",
    " J_w_squence_base[0, :], LDA_squence_base[0, :]) = W(\n",
    "    tf.constant(s_train_images), s_train_labels.reshape(-1, 1))\n",
    "\n",
    "# Loop-invariant inputs, prepared once.\n",
    "# The images are reshaped to the model's declared (28, 28, 1) input shape\n",
    "# explicitly instead of relying on Keras to add the channel axis implicitly.\n",
    "s_train_inputs = s_train_images.reshape(-1, 28, 28, 1)\n",
    "s_train_label_column = s_train_labels.reshape(-1, 1)\n",
    "validation_data = (test_images.reshape(-1, 28, 28, 1), test_labels)\n",
    "\n",
    "for iter_e in range(num_epochs):\n",
    "\n",
    "    # Train for one epoch; the step count follows batch_size (the value the\n",
    "    # generators were built with) instead of a hard-coded 32\n",
    "    model.fit(train_generator, epochs=1,\n",
    "              steps_per_epoch=train_images.shape[0] // batch_size,\n",
    "              validation_data=validation_data)\n",
    "\n",
    "    # Outputs of every recorded layer for the stored training images\n",
    "    x = activation_model(s_train_inputs)\n",
    "\n",
    "    # Loss and accuracy after this epoch\n",
    "    train_loss_squence[iter_e], train_accuracy_squence[iter_e] = model.evaluate(train_generator)\n",
    "    test_loss_squence[iter_e], test_accuracy_squence[iter_e] = model.evaluate(validation_generator)\n",
    "\n",
    "    # The four values W returns for every recorded layer's output\n",
    "    for layers_i in range(len(x) - reserved_layers):\n",
    "        (LS_1_squence[layers_i, iter_e], LS_2_squence[layers_i, iter_e],\n",
    "         J_w_squence[layers_i, iter_e], LDA_squence[layers_i, iter_e]) = W(\n",
    "            x[layers_i + reserved_layers], s_train_label_column)\n",
    "\n",
    "    print('**********'+'the',iter_e,'epochs has finished'+'**********')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d147e4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T08:06:49.514887Z",
     "start_time": "2023-05-14T08:06:49.510557Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import warnings\n",
    "\n",
    "# Suppress every warning raised from this point on\n",
    "warnings.filterwarnings(action=\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d558bba",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T08:06:49.538048Z",
     "start_time": "2023-05-14T08:06:49.517871Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_Separability_figure(layer_name_list,x_plot,LS_1_squence,LS_2_squence,J_w_squence,LDA_squence,LS_1_squence_base,LS_2_squence_base,J_w_squence_base,LDA_squence_base):\n",
    "    \"\"\"Plot the four recorded measures against x_plot in a 2x2 grid of panels.\n",
    "\n",
    "    Every panel draws one curve per plotted row of its measure matrix plus a\n",
    "    black base-line curve. Panels 2 and 4 carry a legend.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    layer_name_list : sequence of str\n",
    "        Curve labels; entry i labels row i of every measure matrix. Rows whose\n",
    "        label contains 'flatten' are not plotted.\n",
    "    x_plot : array-like\n",
    "        Shared x-axis values, one per column of the matrices.\n",
    "    LS_1_squence, LS_2_squence, J_w_squence, LDA_squence : 2-D arrays\n",
    "        Measure matrices, indexed as [row, column]. The last row of each\n",
    "        matrix is not drawn.\n",
    "    LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base : 2-D arrays\n",
    "        Base-line records; only row 0 is drawn.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The matplotlib figure holding the four panels.\n",
    "    \"\"\"\n",
    "    figure = plt.figure(figsize=(16,10))\n",
    "\n",
    "    # (measure matrix, base line, panel title, draw legend?) in subplot order.\n",
    "    # Raw strings keep the LaTeX backslash from being parsed as an invalid\n",
    "    # Python escape sequence ('\\m').\n",
    "    panels = [\n",
    "        (LS_1_squence, LS_1_squence_base, r'$\\mathrm{LS}_1$', False),\n",
    "        (LS_2_squence, LS_2_squence_base, r'$\\mathrm{LS}_2$', True),\n",
    "        (J_w_squence, J_w_squence_base, r'$\\mathrm{J}_w$', False),\n",
    "        (LDA_squence, LDA_squence_base, r'$\\mathrm{LDA}$', True),\n",
    "    ]\n",
    "\n",
    "    for position, (measure, base_line, title, with_legend) in enumerate(panels, start=1):\n",
    "        ax = figure.add_subplot(2, 2, position)\n",
    "\n",
    "        # NOTE(review): the last row is skipped on purpose here - presumably\n",
    "        # the output layer; confirm against the caller.\n",
    "        for i in range(len(measure) - 1):\n",
    "            if 'flatten' not in layer_name_list[i]:\n",
    "                ax.plot(x_plot, measure[i, :], label=layer_name_list[i])\n",
    "\n",
    "        ax.plot(x_plot, base_line[0, :], '-k', label='base_line')\n",
    "        ax.set_title(title)\n",
    "\n",
    "        if with_legend:\n",
    "            # Anchor the legend's lower-left corner at the panel's lower-right corner\n",
    "            ax.legend(bbox_to_anchor=(1, 0), loc=3)\n",
    "\n",
    "    figure.tight_layout()\n",
    "\n",
    "    return figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb0d2b4c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T08:06:51.618511Z",
     "start_time": "2023-05-14T08:06:49.540902Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this star import runs after plot_Separability_figure was defined\n",
    "# in the notebook; if the module exports the same name, the module's version\n",
    "# replaces the notebook's one here - confirm which of the two is intended.\n",
    "from Linear_Separability_numpy_mul import *\n",
    "\n",
    "# Layer labels for the plotted curves (get_layer_name is presumably provided by the module above)\n",
    "layer_name_list = get_layer_name(model)[reserved_layers:]\n",
    "\n",
    "Separability_figure = plot_Separability_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    LS_1_squence, LS_2_squence, J_w_squence, LDA_squence,\n",
    "    LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base)\n",
    "\n",
    "net_figure = plot_net_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    train_loss_squence, train_accuracy_squence,\n",
    "    test_loss_squence, test_accuracy_squence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1366544",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T08:06:51.627954Z",
     "start_time": "2023-05-14T08:06:51.621269Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bundle every recorded result into one dict, keyed by its variable name\n",
    "info = {\n",
    "    'layer_name_list': layer_name_list,\n",
    "    'x_plot': x_plot,\n",
    "    'LS_1_squence': LS_1_squence,\n",
    "    'LS_2_squence': LS_2_squence,\n",
    "    'J_w_squence': J_w_squence,\n",
    "    'LDA_squence': LDA_squence,\n",
    "    'train_loss_squence': train_loss_squence,\n",
    "    'train_accuracy_squence': train_accuracy_squence,\n",
    "    'test_loss_squence': test_loss_squence,\n",
    "    'test_accuracy_squence': test_accuracy_squence,\n",
    "    'LS_1_squence_base': LS_1_squence_base,\n",
    "    'LS_2_squence_base': LS_2_squence_base,\n",
    "    'J_w_squence_base': J_w_squence_base,\n",
    "    'LDA_squence_base': LDA_squence_base,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "628ae843",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
