{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/pravinkr/alexnet-cifar10-using-keras/blob/master/cifar_10_with_Alexnet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:22.913536Z",
     "start_time": "2023-05-05T11:18:22.884788Z"
    },
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 80
    },
    "colab_type": "code",
    "id": "0uoD7EUdjgXP",
    "outputId": "585d5484-d21c-4226-dad9-e1918d512ab4"
   },
   "outputs": [],
   "source": [
    "#Simple CNN model for CIFAR-10 dataset\n",
    "import numpy as np\n",
    "\n",
    "# Simple CNN model for CIFAR-10\n",
    "import numpy as np\n",
    "import os\n",
    "from tensorflow.keras.datasets import cifar10, cifar100\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Conv2D\n",
    "from tensorflow.keras.layers import Dropout\n",
    "from tensorflow.keras.layers import Flatten\n",
    "from tensorflow.keras.constraints import max_norm\n",
    "from tensorflow.keras.optimizers import SGD, Adam\n",
    "from tensorflow.keras.layers import Convolution2D\n",
    "from tensorflow.keras.layers import MaxPooling2D\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler\n",
    "from tensorflow.keras.callbacks import ReduceLROnPlateau\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from tensorflow.keras import backend as K\n",
    "import tensorflow as tf\n",
    "from data_utils import *\n",
    "\n",
    "#from tensorflow.keras.utils import np_utils\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "%matplotlib inline\n",
    "\n",
    "import skimage\n",
    "from skimage.util import img_as_ubyte\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:28:58.452711Z",
     "start_time": "2023-05-05T11:28:53.074787Z"
    },
    "run_control": {
     "marked": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CIFAR10 Training data shape: (50000, 32, 32, 3)\n",
      "CIFAR10 Training label shape (50000, 1)\n",
      "CIFAR10 Test data shape (10000, 32, 32, 3)\n",
      "CIFAR10 Test label shape (10000, 1)\n",
      "num train:9000 num val:1000\n"
     ]
    }
   ],
   "source": [
    "# get data\n",
    "cifar10_data = CIFAR10Data()\n",
    "\n",
    "x_train_o, y_train_o, x_test_o, y_test_o = cifar10_data.get_data(subtract_mean=True)\n",
    "\n",
    "all_index_train = np.where((np.argmax(y_train_o,axis=1)==0 ) | (np.argmax(y_train_o,axis=1)==1))[0]\n",
    "all_index_test = np.where((np.argmax(y_test_o,axis=1)==0 ) | (np.argmax(y_test_o,axis=1)==1))[0]\n",
    "\n",
    "np.random.seed(1234)\n",
    "LS_index = all_index_train[np.random.randint(0,len(all_index_train),2000)]\n",
    "\n",
    "Separability_images,Separability_labels = x_train_o[LS_index,:], y_train_o[LS_index,]\n",
    "\n",
    "######################Separability_base_images##################################################\n",
    "cifar10 = tf.keras.datasets.cifar10\n",
    "(x_train_s, y_train_s), (x_test_s, y_test_s) = cifar10.load_data()\n",
    "\n",
    "Separability_images_base,Separability_labels_base = x_train_s[LS_index,:], y_train_s[LS_index,]\n",
    "Separability_images_base = Separability_images_base/255.0\n",
    "Separability_images_base = Separability_images_base.astype('float32')\n",
    "######################Separability_base_images##################################################\n",
    "\n",
    "\n",
    "\n",
    "x_train,y_train = x_train_o[all_index_train,:],y_train_o[all_index_train,:]\n",
    "x_test,y_test = x_test_o[all_index_test,:],y_test_o[all_index_test,:]\n",
    "\n",
    "\n",
    "# ##########################################\n",
    "# y_train = np.argmax(y_train,axis=1)\n",
    "# y_test = np.argmax(y_test,axis=1)\n",
    "\n",
    "# from tensorflow.keras.utils import to_categorical\n",
    "# y_train = to_categorical(y_train, num_classes=2)\n",
    "# y_test = to_categorical(y_test, num_classes=2)\n",
    "# ##########################################\n",
    "\n",
    "num_train = int(x_train.shape[0] * 0.9)\n",
    "num_val = x_train.shape[0] - num_train\n",
    "mask = list(range(num_train, num_train+num_val))\n",
    "x_val = x_train[mask]\n",
    "y_val = y_train[mask]\n",
    "\n",
    "mask = list(range(num_train))\n",
    "x_train = x_train[mask]\n",
    "y_train = y_train[mask]\n",
    "\n",
    "print('num train:%d num val:%d' % (num_train, num_val))\n",
    "data = (x_train, y_train, x_val, y_val, x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:28.581603Z",
     "start_time": "2023-05-05T11:17:27.011Z"
    },
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "k34XYTn_gbVF",
    "outputId": "f7eb15db-1aa0-43b3-b44e-e7d194041c59"
   },
   "outputs": [],
   "source": [
    "#Defining Variables\n",
    "\n",
    "#Data set information\n",
    "DATASET = 'cifar-10'\n",
    "#DATASET = 'cifar-100'\n",
    "input_shape=(32,32,3)\n",
    "\n",
    "num_classes = 10\n",
    "\n",
    "model_type = 'Alexnet'\n",
    "\n",
    "epochs = 25\n",
    "lrate = 0.01\n",
    "decay = lrate/epochs\n",
    "batch_size = 32\n",
    "\n",
    "data_augmentation = True\n",
    "flow_from_dir = False\n",
    "\n",
    "subtract_mean = True\n",
    "seed = 7\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:28.584312Z",
     "start_time": "2023-05-05T11:17:28.060Z"
    },
    "colab": {},
    "colab_type": "code",
    "id": "hXFuLmd6bjpY"
   },
   "outputs": [],
   "source": [
    "#Define Alexnet Model\n",
    "def AlexnetModel(input_shape, num_classes):\n",
    "    model = Sequential()\n",
    "    model.add(\n",
    "        Conv2D(filters=96,\n",
    "               kernel_size=(3, 3),\n",
    "               strides=(4, 4),\n",
    "               input_shape=input_shape,\n",
    "               activation='relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
    "    model.add(Conv2D(256, (5, 5), padding='same', activation='relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
    "    model.add(Conv2D(384, (3, 3), padding='same', activation='relu'))\n",
    "    model.add(Conv2D(384, (3, 3), padding='same', activation='relu'))\n",
    "    model.add(Conv2D(256, (3, 3), padding='same', activation='relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
    "\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(4096, activation='relu'))\n",
    "    model.add(Dropout(0.4))\n",
    "    model.add(Dense(4096, activation='relu'))\n",
    "    model.add(Dropout(0.4))\n",
    "    model.add(Dense(num_classes, activation='softmax'))\n",
    "\n",
    "    #model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])\n",
    "\n",
    "    #model.summary()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:28.586788Z",
     "start_time": "2023-05-05T11:17:28.514Z"
    },
    "colab": {},
    "colab_type": "code",
    "id": "eRNzvi-qZXvl"
   },
   "outputs": [],
   "source": [
    "def get_model(input_shape, num_classes):\n",
    "    # Create the model - Normal model with CNN, dropouts and MaxPooling.\n",
    "    # This model gives accuracy of 77% on the test set after 25 epochs\n",
    "\n",
    "    model = Sequential()\n",
    "\n",
    "    #model.add(Conv2D(32,(3,3),input_shape=(32,32,3),padding='same',activation='relu'))\n",
    "    model.add(\n",
    "        Conv2D(32, (3, 3),\n",
    "               input_shape=input_shape,\n",
    "               padding='same',\n",
    "               activation='relu'))\n",
    "    model.add(Dropout(0.2))\n",
    "\n",
    "    model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))\n",
    "    model.add(Dropout(0.2))\n",
    "\n",
    "    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
    "    model.add(Dropout(0.2))\n",
    "\n",
    "    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Flatten())\n",
    "    model.add(Dropout(0.2))\n",
    "\n",
    "    model.add(Dense(1024, activation='relu'))\n",
    "    model.add(Dropout(0.2))\n",
    "\n",
    "    model.add(Dense(512, activation='relu'))\n",
    "    model.add(Dropout(0.2))\n",
    "\n",
    "    model.add(Dense(num_classes, activation='softmax'))\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:28.589147Z",
     "start_time": "2023-05-05T11:17:29.282Z"
    },
    "colab": {},
    "colab_type": "code",
    "id": "p2cGLs6n9l0L"
   },
   "outputs": [],
   "source": [
    "def lr_schedule(epoch):\n",
    "    \"\"\"Learning Rate Schedule\n",
    "\n",
    "    Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs.\n",
    "    Called automatically every epoch as part of callbacks during training.\n",
    "\n",
    "    # Arguments\n",
    "        epoch (int): The number of epochs\n",
    "\n",
    "    # Returns\n",
    "        lr (float32): learning rate\n",
    "    \"\"\"\n",
    "    lr = 1e-3\n",
    "    if epoch > 180:\n",
    "        lr *= 0.5e-3\n",
    "    elif epoch > 160:\n",
    "        lr *= 1e-3\n",
    "    elif epoch > 120:\n",
    "        lr *= 1e-2\n",
    "    elif epoch > 80:\n",
    "        lr *= 1e-1\n",
    "    print('Learning rate: ', lr)\n",
    "    return lr\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:28.591299Z",
     "start_time": "2023-05-05T11:17:30.006Z"
    },
    "colab": {},
    "colab_type": "code",
    "id": "sdnE9Sd0-NbF",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Prepare callbacks for model saving and for learning rate adjustment.\n",
    "\n",
    "early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=5, verbose=0, mode='auto', baseline=None, restore_best_weights=True)\n",
    "\n",
    "\n",
    "\n",
    "lr_scheduler = LearningRateScheduler(lr_schedule)\n",
    "\n",
    "lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),\n",
    "                               cooldown=0,\n",
    "                               patience=5,\n",
    "                               min_lr=0.5e-6)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:29.297394Z",
     "start_time": "2023-05-05T11:18:29.147355Z"
    },
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 972
    },
    "colab_type": "code",
    "id": "AbLu2lefac_h",
    "outputId": "7b1e97aa-92ac-457f-e905-ad55b736cedd",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model Summary of  Alexnet\n",
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_5 (Conv2D)            (None, 8, 8, 96)          2688      \n",
      "_________________________________________________________________\n",
      "max_pooling2d_3 (MaxPooling2 (None, 4, 4, 96)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_6 (Conv2D)            (None, 4, 4, 256)         614656    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_4 (MaxPooling2 (None, 2, 2, 256)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_7 (Conv2D)            (None, 2, 2, 384)         885120    \n",
      "_________________________________________________________________\n",
      "conv2d_8 (Conv2D)            (None, 2, 2, 384)         1327488   \n",
      "_________________________________________________________________\n",
      "conv2d_9 (Conv2D)            (None, 2, 2, 256)         884992    \n",
      "_________________________________________________________________\n",
      "max_pooling2d_5 (MaxPooling2 (None, 1, 1, 256)         0         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 256)               0         \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 4096)              1052672   \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 4096)              0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 4096)              16781312  \n",
      "_________________________________________________________________\n",
      "dropout_3 (Dropout)          (None, 4096)              0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 10)                40970     \n",
      "=================================================================\n",
      "Total params: 21,589,898\n",
      "Trainable params: 21,589,898\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "model = AlexnetModel(input_shape,num_classes)\n",
    "\n",
    "#optimizer = SGD(lr=lrate, momentum=0.9, decay=decay, nesterov=False)\n",
    "optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)\n",
    "model.compile(loss= 'categorical_crossentropy' , optimizer=optimizer, metrics=[ 'accuracy' ])\n",
    "print(\"Model Summary of \",model_type)\n",
    "print(model.summary())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:18:29.831927Z",
     "start_time": "2023-05-05T11:18:29.818606Z"
    }
   },
   "outputs": [],
   "source": [
    "layer_outputs = [\n",
    "    layer.output for layer in model.layers if 'input' not in layer.name\n",
    "]\n",
    "activation_model = tf.keras.models.Model(inputs=model.input,\n",
    "                                         outputs=layer_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:29:00.241388Z",
     "start_time": "2023-05-05T11:28:59.611678Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy import *\n",
    "\n",
    "# parameters (train)\n",
    "num_epochs = 100\n",
    "\n",
    "batch_ep = 1\n",
    "\n",
    "x_plot = np.arange(num_epochs)*batch_ep\n",
    "reserved_layers = 0\n",
    "\n",
    "\n",
    "\n",
    "# initialize record matrix\n",
    "x=activation_model.predict(Separability_images)\n",
    "LS_1_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LS_2_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "J_w_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LDA_squence = np.zeros((len(x)-reserved_layers,num_epochs))\n",
    "LS_1_squence_base = np.zeros((1,num_epochs))\n",
    "LS_2_squence_base = np.zeros((1,num_epochs))\n",
    "J_w_squence_base = np.zeros((1,num_epochs))\n",
    "LDA_squence_base = np.zeros((1,num_epochs))\n",
    "train_loss_squence = np.zeros((num_epochs,))\n",
    "train_accuracy_squence = np.zeros((num_epochs,))\n",
    "test_loss_squence = np.zeros((num_epochs,))\n",
    "test_accuracy_squence = np.zeros((num_epochs,))\n",
    "\n",
    "Separability_labels = np.argmax(Separability_labels,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:25:13.521712Z",
     "start_time": "2023-05-05T11:25:05.284326Z"
    }
   },
   "outputs": [],
   "source": [
    "LS_1_squence_base[0,:],LS_2_squence_base[0,:],\\\n",
    "J_w_squence_base[0,:],\\\n",
    "LDA_squence_base[0,:]=W(tf.constant(Separability_images_base),Separability_labels_base.reshape(-1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:34:50.196603Z",
     "start_time": "2023-05-05T11:34:33.975167Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: shape=(), dtype=float32, numpy=0.49127802>,\n",
       " 0.6583708,\n",
       " 288168.47,\n",
       " array([[0.00040401]], dtype=float32))"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W(tf.constant(Separability_images_base),Separability_labels_base.reshape(-1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:30:54.207071Z",
     "start_time": "2023-05-05T11:30:54.183777Z"
    },
    "code_folding": [
     0
    ],
    "run_control": {
     "marked": true
    }
   },
   "outputs": [],
   "source": [
    "#计算LS_1度量所用函数\n",
    "def LS_1(w,M,r_1,r_2):\n",
    "    \n",
    "    LS_1_value_one = tf.reduce_sum(tf.sign(np.dot(w,M.T)))\n",
    "    LS_1_value_two = tf.reduce_sum(tf.sign(np.dot(-w,M.T)))\n",
    "    LS_1_value = max(LS_1_value_one,LS_1_value_two)/(r_1*r_2)\n",
    "    \n",
    "    return LS_1_value\n",
    "\n",
    "#计算LS_2度量所用函数\n",
    "def LS_2(w,M,r_1,r_2):\n",
    "    \n",
    "    L2_up = np.abs(np.sum(np.dot(w,M.T)))\n",
    "    L2_down = np.sum(np.abs(np.dot(w,M.T)))\n",
    "    LS_2_value = L2_up/L2_down\n",
    "    \n",
    "    return LS_2_value\n",
    "    \n",
    "    \n",
    "#计算Jw度量所用函数\n",
    "def J_w(w,M,r_1,r_2):\n",
    "    \n",
    "    J_w_up = np.square(np.sum(np.dot(w,M.T)))\n",
    "    J_w_down = np.sum(np.square(np.dot(w,M.T)))\n",
    "    J_w_value = J_w_up/J_w_down\n",
    "    \n",
    "    return J_w_value\n",
    "\n",
    "#计算LDA度量所用函数\n",
    "def LDA(w,new_data_A,new_data_B):\n",
    "    \n",
    "    w = w.T\n",
    "    \n",
    "    A_class = new_data_A.numpy()\n",
    "    B_class = new_data_B.numpy()\n",
    "    A_class = A_class.T\n",
    "    B_class = B_class.T\n",
    "\n",
    "    nu_a = np.mean(A_class,axis=1).reshape(-1,1)\n",
    "    nu_b = np.mean(B_class,axis=1).reshape(-1,1)\n",
    "\n",
    "    A_c = A_class-np.repeat(nu_a,A_class.shape[1],axis=1)\n",
    "    B_c = B_class-np.repeat(nu_b,B_class.shape[1],axis=1)\n",
    "\n",
    "    S_w = np.dot(A_c,A_c.T)+np.dot(B_c,B_c.T)\n",
    "    \n",
    "    S_b = np.dot(nu_a-nu_b,(nu_a-nu_b).T)\n",
    "\n",
    "\n",
    "    LDA_value_up = np.dot(np.dot(w.T,S_b),w)\n",
    "    LDA_value_down = np.dot(np.dot(w.T,S_w),w)\n",
    "    LDA_value = LDA_value_up/LDA_value_down\n",
    "    \n",
    "    return LDA_value\n",
    "\n",
    "\n",
    "#该函数通过输入每层的输出和标签值，得到LS_1，LS_2，Jw\n",
    "def W(original_X,original_Y):\n",
    "    # 数据转换为numpy\n",
    "    # Label是列向量\n",
    "\n",
    "    original_X = tf.constant(original_X)\n",
    "    original_Y = tf.constant(original_Y)\n",
    "\n",
    "    data_A = tf.gather(original_X, axis=0, indices=tf.where(original_Y==1)[:,0])\n",
    "    data_B = tf.gather(original_X, axis=0, indices=tf.where(original_Y==0)[:,0])\n",
    "\n",
    "    r_1 = len(data_A)\n",
    "    r_2 = len(data_B)\n",
    "\n",
    "    new_data_A = tf.reshape(data_A,(data_A.shape[0],-1))\n",
    "    new_data_B = tf.reshape(data_B,(data_B.shape[0],-1))\n",
    "\n",
    "    if new_data_A.shape[1]>=10000 or new_data_B.shape[1]>=10000:\n",
    "        new_data_A = tf.reduce_mean(data_A,axis=-1)\n",
    "        new_data_B = tf.reduce_mean(data_B,axis=-1)\n",
    "        new_data_A = tf.reshape(new_data_A,(new_data_A.shape[0],-1))\n",
    "        new_data_B = tf.reshape(new_data_B,(new_data_B.shape[0],-1))\n",
    "\n",
    "    M = np.zeros((r_1*r_2,new_data_A.shape[1]),dtype='float32')\n",
    "\n",
    "    index_base = np.arange(r_2)\n",
    "\n",
    "    # M行数是r_1*r_2 列数是每个样本对应输出展开的维数\n",
    "    for i in range(r_1):\n",
    "        M[index_base+i*r_2,:]=new_data_A[i]-new_data_B  \n",
    "\n",
    "\n",
    "\n",
    "    m =np.sum(M,axis=0).reshape(1,-1)\n",
    "    w = m/np.linalg.norm(m)\n",
    "    \n",
    "    LS_1_value = LS_1(w,M,r_1,r_2)\n",
    "    LS_2_value = LS_2(w,M,r_1,r_2)\n",
    "    J_w_value = J_w(w,M,r_1,r_2)\n",
    "    LDA_value = LDA(w,new_data_A,new_data_B)\n",
    "    \n",
    "    return LS_1_value,LS_2_value,J_w_value,LDA_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-05T11:34:29.989161Z",
     "start_time": "2023-05-05T11:34:29.950497Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "#计算LS_1度量所用函数2\n",
    "def LS_1(w,M,r_1,r_2):\n",
    "    \n",
    "    LS_1_value_one = tf.reduce_sum(tf.sign(np.dot(w,M.T)))\n",
    "    LS_1_value_two = tf.reduce_sum(tf.sign(np.dot(-w,M.T)))\n",
    "    LS_1_value = max(LS_1_value_one,LS_1_value_two)/(r_1*r_2)\n",
    "    \n",
    "    return LS_1_value\n",
    "\n",
    "#计算LS_2度量所用函数\n",
    "def LS_2(w,M,r_1,r_2):\n",
    "    \n",
    "    L2_up = np.abs(np.sum(np.dot(w,M.T)))\n",
    "    L2_down = np.sum(np.abs(np.dot(w,M.T)))\n",
    "    LS_2_value = L2_up/L2_down\n",
    "    \n",
    "    return LS_2_value\n",
    "    \n",
    "    \n",
    "#计算Jw度量所用函数\n",
    "def J_w(w,M,r_1,r_2):\n",
    "    \n",
    "    J_w_up = np.square(np.sum(np.dot(w,M.T)))\n",
    "    J_w_down = np.sum(np.square(np.dot(w,M.T)))\n",
    "    J_w_value = J_w_up/J_w_down\n",
    "    \n",
    "    return J_w_value\n",
    "\n",
    "#计算LDA度量所用函数\n",
    "def LDA(w,new_data_A,new_data_B,r_1,r_2):\n",
    "    \n",
    "    w = w.T\n",
    "    \n",
    "    A_class = new_data_A.numpy()\n",
    "    B_class = new_data_B.numpy()\n",
    "    A_class = A_class.T\n",
    "    B_class = B_class.T\n",
    "\n",
    "    nu_a = np.mean(A_class,axis=1).reshape(-1,1)\n",
    "    nu_b = np.mean(B_class,axis=1).reshape(-1,1)\n",
    "\n",
    "    A_c = A_class-np.repeat(nu_a,A_class.shape[1],axis=1)\n",
    "    B_c = B_class-np.repeat(nu_b,B_class.shape[1],axis=1)\n",
    "\n",
    "    S_w = np.dot(A_c,A_c.T)+np.dot(B_c,B_c.T)\n",
    "    \n",
    "    S_b = np.dot(nu_a-nu_b,(nu_a-nu_b).T)\n",
    "\n",
    "\n",
    "    LDA_value_up = np.dot(np.dot(w.T,S_b),w)\n",
    "    LDA_value_down = np.dot(np.dot(w.T,S_w),w)\n",
    "    LDA_value = LDA_value_up/LDA_value_down\n",
    "    \n",
    "    return LDA_value\n",
    "\n",
    "\n",
    "#该函数通过输入每层的输出和标签值，得到LS_1，LS_2，Jw\n",
    "def W(original_X,original_Y):\n",
    "    # 数据转换为numpy\n",
    "    # Label是列向量\n",
    "\n",
    "    original_X = tf.constant(original_X)\n",
    "    original_Y = tf.constant(original_Y)\n",
    "\n",
    "    data_A = tf.gather(original_X, axis=0, indices=tf.where(original_Y==1)[:,0])\n",
    "    data_B = tf.gather(original_X, axis=0, indices=tf.where(original_Y==0)[:,0])\n",
    "\n",
    "    r_1 = len(data_A)\n",
    "    r_2 = len(data_B)\n",
    "\n",
    "    new_data_A = tf.reshape(data_A,(data_A.shape[0],-1))\n",
    "    new_data_B = tf.reshape(data_B,(data_B.shape[0],-1))\n",
    "\n",
    "    if new_data_A.shape[1]>=1000 or new_data_B.shape[1]>=1000:\n",
    "        new_data_A = tf.reduce_mean(data_A,axis=-1)\n",
    "        new_data_B = tf.reduce_mean(data_B,axis=-1)\n",
    "        new_data_A = tf.reshape(new_data_A,(new_data_A.shape[0],-1))\n",
    "        new_data_B = tf.reshape(new_data_B,(new_data_B.shape[0],-1))\n",
    "\n",
    "    M = np.zeros((r_1*r_2,new_data_A.shape[1]),dtype='float32')\n",
    "\n",
    "    index_base = np.arange(r_2)\n",
    "\n",
    "    # M行数是r_1*r_2 列数是每个样本对应输出展开的维数\n",
    "    for i in range(r_1):\n",
    "        M[index_base+i*r_2,:]=new_data_A[i]-new_data_B  \n",
    "\n",
    "\n",
    "\n",
    "    m =np.sum(M,axis=0).reshape(1,-1)\n",
    "    w = m/np.linalg.norm(m)\n",
    "    \n",
    "    LS_1_value = LS_1(w,M,r_1,r_2)\n",
    "    LS_2_value = LS_2(w,M,r_1,r_2)\n",
    "    J_w_value = J_w(w,M,r_1,r_2)\n",
    "    LDA_value = LDA(w,new_data_A,new_data_B,r_1,r_2)\n",
    "    \n",
    "    return LS_1_value,LS_2_value,J_w_value,LDA_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-03T16:59:27.703Z"
    },
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "AIEd9V4KMYR7",
    "outputId": "58a18034-df34-4209-f617-410ca42103e5",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <timed exec>:50: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use Model.fit, which supports generators.\n",
      "282/282 [==============================] - 8s 28ms/step - loss: 0.6000 - accuracy: 0.7891 - val_loss: 0.4512 - val_accuracy: 0.7760 - lr: 0.0010\n",
      "282/282 [==============================] - 8s 28ms/step - loss: 0.4510 - accuracy: 0.7820\n",
      "32/32 [==============================] - 0s 4ms/step - loss: 0.4512 - accuracy: 0.7760\n",
      "**********the 0 epochs has finished**********\n",
      "163/282 [================>.............] - ETA: 3s - loss: 0.3650 - accuracy: 0.8475"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import ReduceLROnPlateau\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler\n",
    "# fit data with data augmentation or not\n",
    "data_augmentation = True\n",
    "\n",
    "\n",
    "base_lr = 1e-3\n",
    "\n",
    "def lr_scheduler(epoch, lr):\n",
    "    global base_lr\n",
    "    \n",
    "    new_lr = base_lr\n",
    "    if epoch <= 2:\n",
    "        pass\n",
    "    elif epoch > 2 and epoch <= 80:\n",
    "        new_lr = base_lr * 0.1\n",
    "    else:\n",
    "        new_lr = base_lr * 0.01\n",
    "    return new_lr\n",
    "\n",
    "def lr_scheduler2(epoch, lr):\n",
    "    #print( \"Learning rate:\", lr)\n",
    "    return lr\n",
    "\n",
    "callbacks = [LearningRateScheduler(lr_scheduler2)]\n",
    "\n",
    "datagen = ImageDataGenerator(\n",
    "    featurewise_center=False,  # set input mean to 0 over the dataset\n",
    "    samplewise_center=False,  # set each sample mean to 0\n",
    "    featurewise_std_normalization=False,  # divide inputs by std of the dataset\n",
    "    samplewise_std_normalization=False,  # divide each input by its std\n",
    "    zca_whitening=False,  # apply ZCA whitening\n",
    "    rotation_range=15,  # randomly rotate images in the range (degrees, 0 to 180)\n",
    "    width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)\n",
    "    height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)\n",
    "    horizontal_flip=True,  # randomly flip images\n",
    "    vertical_flip=False, # randomly flip images\n",
    ") \n",
    "\n",
    "for iter_e in range(num_epochs):\n",
    "    \n",
    "    optimizer.lr = lr_scheduler(iter_e, optimizer.lr)\n",
    "    \n",
    "    train_gen = datagen.flow(x_train, y_train, batch_size=batch_size)\n",
    "    \n",
    "    history = model.fit_generator(generator=train_gen,\n",
    "                                  epochs=1,\n",
    "                                  validation_data=(x_val, y_val),\n",
    "                                  callbacks=callbacks\n",
    "                                 )\n",
    "        \n",
    "    x=activation_model(Separability_images)\n",
    "    #train_gen = datagen.flow(x_train, y_train, batch_size=batch_size)\n",
    "    # Loss and Acc by epochs\n",
    "    train_loss_squence[iter_e],train_accuracy_squence[iter_e]=model.evaluate(train_gen)\n",
    "    test_loss_squence[iter_e],test_accuracy_squence[iter_e]=model.evaluate(x_val, y_val)\n",
    "\n",
    "    # LS of every layer's output\n",
    "    for layers_i in range(len(x)-reserved_layers):\n",
    "        LS_1_squence[layers_i,iter_e],LS_2_squence[layers_i,iter_e],J_w_squence[layers_i,iter_e],\\\n",
    "        LDA_squence[layers_i,iter_e]=W(x[layers_i+reserved_layers],Separability_labels.reshape(-1,1))    \n",
    "\n",
    "    print('**********'+'the',iter_e,'epochs has finished'+'**********')   \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-03T16:59:27.704Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import time\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-03T16:59:27.705Z"
    }
   },
   "outputs": [],
   "source": [
    "layer_name_list = get_layer_name(model)[reserved_layers:]\n",
    "Separability_figure = plot_Separability_figure(layer_name_list,x_plot,LS_1_squence,LS_2_squence,J_w_squence,LDA_squence,LS_1_squence_base,LS_2_squence_base,J_w_squence_base,LDA_squence_base)\n",
    "net_figure = plot_net_figure(layer_name_list,x_plot,train_loss_squence,train_accuracy_squence,test_loss_squence,test_accuracy_squence)\n",
    "\n",
    "time_tuple = time.localtime(time.time())\n",
    "figure_name = \"Alex-结束时间为{}年{}月{}日{}点{}分{}秒.png\".format(time_tuple[0],time_tuple[1],time_tuple[2],time_tuple[3],time_tuple[4],time_tuple[5])\n",
    "Separability_figure.savefig('./Separability_figure/'+figure_name,dpi=100,bbox_inches = 'tight')\n",
    "net_figure.savefig('./net_figure/'+figure_name,dpi=100,bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-03T16:59:27.707Z"
    }
   },
   "outputs": [],
   "source": [
    "info={'layer_name_list':layer_name_list,'x_plot':x_plot,'LS_1_squence':LS_1_squence,'LS_2_squence':LS_2_squence,'J_w_squence':J_w_squence,\n",
    "            'LDA_squence':LDA_squence,'train_loss_squence':train_loss_squence,'train_accuracy_squence':train_accuracy_squence,'test_loss_squence':test_loss_squence,'test_accuracy_squence':test_accuracy_squence,'LS_1_squence_base':LS_1_squence_base,'LS_2_squence_base':LS_2_squence_base,'J_w_squence_base':J_w_squence_base,\n",
    "            'LDA_squence_base':LDA_squence_base,}\n",
    "#将结果进行保存\n",
    "import pickle\n",
    "file_name = \"Alex-结束时间为{}年{}月{}日{}点{}分{}秒.pkl\".format(time_tuple[0],time_tuple[1],time_tuple[2],time_tuple[3],time_tuple[4],time_tuple[5])\n",
    "f = open('./saved_data/'+file_name,'wb')\n",
    "pickle.dump(info,f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "machine_shape": "hm",
   "name": "cifar-10-with-Alexnet.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
