{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DBN (stacked BernoulliRBM) pre-training and MLP fine-tuning on MNIST, tracking\n",
    "# the linear separability of each layer's activations during training.\n",
    "#\n",
    "# NOTE(review): much of this import block (CIFAR datasets, conv/pooling layers,\n",
    "# optimizers, callbacks, regularizers, skimage) is inherited from a CNN/CIFAR-10\n",
    "# notebook and is not used below; it is kept so nothing that may depend on it breaks.\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras.callbacks import (EarlyStopping, LearningRateScheduler,\n",
    "                                        ModelCheckpoint, ReduceLROnPlateau)\n",
    "from tensorflow.keras.constraints import max_norm\n",
    "from tensorflow.keras.datasets import cifar10, cifar100\n",
    "from tensorflow.keras.layers import (Conv2D, Convolution2D, Dense, Dropout,\n",
    "                                     Flatten, MaxPooling2D)\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.optimizers import SGD, Adam\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.regularizers import l2\n",
    "import skimage\n",
    "from skimage.util import img_as_ubyte\n",
    "\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "SEED = 1234\n",
    "N_TRAIN_SUBSET = 2000   # candidate pool of digit-0/1 training images\n",
    "N_SEPARABILITY = 500    # images used to measure linear separability\n",
    "\n",
    "# Seed *before* any random draw so every subset below is reproducible.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "mnist = tf.keras.datasets.mnist\n",
    "(x_train_o, y_train_o), (x_test_o, y_test_o) = mnist.load_data()\n",
    "\n",
    "# Flatten each image to a vector and scale pixels to [0, 1].\n",
    "x_train_o = (x_train_o.reshape(x_train_o.shape[0], -1) / 255.0).astype('float32')\n",
    "x_test_o = (x_test_o.reshape(x_test_o.shape[0], -1) / 255.0).astype('float32')\n",
    "\n",
    "train_x, train_y = x_train_o.copy(), y_train_o.copy()\n",
    "test_x, test_y = x_test_o.copy(), y_test_o.copy()\n",
    "\n",
    "# Indices of the two classes (digits 0 and 1) used for the separability probe.\n",
    "all_index_train = np.where(np.isin(y_train_o, (0, 1)))[0]\n",
    "all_index_test = np.where(np.isin(y_test_o, (0, 1)))[0]\n",
    "\n",
    "# Sample without replacement so no image is counted twice in the scatter statistics.\n",
    "train_index = np.random.choice(all_index_train, N_TRAIN_SUBSET, replace=False)\n",
    "LS_index = np.random.choice(train_index, N_SEPARABILITY, replace=False)\n",
    "\n",
    "Separability_images, Separability_labels = train_x[LS_index, :], train_y[LS_index]\n",
    "\n",
    "# One-hot encode only after the integer separability labels were extracted above.\n",
    "train_y = to_categorical(train_y)\n",
    "# train_x, train_y = train_x[train_index, :], train_y[train_index, :]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hidden-layer widths of the DBN; each entry is pre-trained as one BernoulliRBM.\n",
    "layers = [100]\n",
    "outputs = 10  # number of MNIST classes for the softmax head\n",
    "\n",
    "# Per-layer RBM epochs / learning rates. The literal lists may hold more entries\n",
    "# than there are layers; they are truncated so that sum(rbm_iters), which sizes\n",
    "# the record matrices and offsets the fine-tuning columns, only counts epochs that\n",
    "# are actually run (otherwise the plots get a long run of empty columns).\n",
    "rbm_iters = [100, 100, 100, 100, 100][:len(layers)]\n",
    "rbm_lr = [0.01, 0.01, 0.01, 0.01, 0.01][:len(layers)]\n",
    "assert len(rbm_iters) == len(layers) and len(rbm_lr) == len(layers), (\n",
    "    'need one rbm_iters / rbm_lr entry per hidden layer')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative quick-run configuration (single RBM, fewer epochs).\n",
    "# Uncomment to override the settings from the previous cell.\n",
    "# layers = [100]\n",
    "# outputs = 10\n",
    "# rbm_iters = [50]\n",
    "# rbm_lr = [0.01]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dependencies for RBM pre-training and the Keras MLP head.\n",
    "import numpy as np\n",
    "from sklearn import datasets, linear_model\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.neural_network import BernoulliRBM\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n",
    "from tensorflow.keras.layers import Activation, Dense\n",
    "from tensorflow.keras.models import Sequential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The hidden layers are initialised from BernoulliRBM weights, and\n",
    "# BernoulliRBM.transform computes sigmoid(X @ W.T + b). Using the same activation\n",
    "# keeps the tracked Keras features identical to the RBM features the next RBM is\n",
    "# trained on (relu would silently yield a different representation).\n",
    "HIDDEN_ACTIVATION = 'sigmoid'\n",
    "\n",
    "model = Sequential()\n",
    "for i, width in enumerate(layers):\n",
    "    # Only the first hidden layer declares the flattened input size.\n",
    "    input_kwargs = {'input_dim': train_x.shape[1]} if i == 0 else {}\n",
    "    model.add(Dense(width,\n",
    "                    activation=HIDDEN_ACTIVATION,\n",
    "                    name='rbm_{}'.format(i),\n",
    "                    **input_kwargs))\n",
    "\n",
    "model.add(Dense(outputs, activation='softmax'))\n",
    "model.compile(optimizer='Adam',\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Auxiliary model exposing every non-input layer's activation, so the\n",
    "# separability of each representation is available from one forward pass.\n",
    "layer_outputs = [lyr.output for lyr in model.layers\n",
    "                 if 'input' not in lyr.name]\n",
    "activation_model = tf.keras.models.Model(\n",
    "    inputs=model.input,\n",
    "    outputs=layer_outputs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# W: separability metrics of one layer's activations; plot_*: figure helpers.\n",
    "# Explicit names instead of `import *` so it is clear where they come from.\n",
    "from Linear_Separability_numpy import W, plot_Separability_figure, plot_net_figure\n",
    "\n",
    "# One record column per RBM epoch plus 20 fine-tuning epochs.\n",
    "num_epochs = sum(rbm_iters) + 20\n",
    "\n",
    "batch_ep = 1\n",
    "\n",
    "x_plot = np.arange(num_epochs) * batch_ep\n",
    "reserved_layers = 0\n",
    "\n",
    "# One forward pass, only to learn how many layer outputs are tracked.\n",
    "x = activation_model.predict(Separability_images)\n",
    "n_tracked = len(x) - reserved_layers\n",
    "\n",
    "# Record matrices: rows = tracked layers (or the raw-input baseline), columns = epochs.\n",
    "LS_1_squence, LS_2_squence, J_w_squence, LDA_squence = (\n",
    "    np.zeros((n_tracked, num_epochs)) for _ in range(4))\n",
    "LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base = (\n",
    "    np.zeros((1, num_epochs)) for _ in range(4))\n",
    "train_loss_squence, train_accuracy_squence, test_loss_squence, test_accuracy_squence = (\n",
    "    np.zeros((num_epochs,)) for _ in range(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Greedy layer-wise RBM pre-training. After every RBM epoch the current RBM\n",
    "# weights are copied into the matching Keras layer and the separability of every\n",
    "# tracked layer is written to column iter_l of the record matrices.\n",
    "RBM_BATCH_SIZE = 32\n",
    "sep_labels = Separability_labels.reshape(-1, 1)  # loop-invariant\n",
    "\n",
    "visual_layer = train_x\n",
    "rbm_weights = []\n",
    "rbm_biases = []\n",
    "rbm_h_act = []\n",
    "\n",
    "for layers_i in range(len(layers)):\n",
    "    print(\"[DBN] Layer {} Pre-Training\".format(layers_i + 1))\n",
    "\n",
    "    rbm = BernoulliRBM(n_components=layers[layers_i],\n",
    "                       learning_rate=rbm_lr[layers_i],\n",
    "                       batch_size=RBM_BATCH_SIZE,\n",
    "                       random_state=1234,\n",
    "                       verbose=True)\n",
    "    layer = model.get_layer('rbm_{}'.format(layers_i))\n",
    "    iter_offset = sum(rbm_iters[:layers_i])\n",
    "\n",
    "    for iter_e in range(rbm_iters[layers_i]):\n",
    "        # One epoch = one pass of mini-batch updates. partial_fit performs a single\n",
    "        # PCD step whose negative phase uses a persistent chain of `batch_size`\n",
    "        # samples, so it must be fed batches of that size; feeding the whole data\n",
    "        # set at once gives one step with a drastically under-weighted negative phase.\n",
    "        for start in range(0, visual_layer.shape[0], rbm.batch_size):\n",
    "            rbm.partial_fit(visual_layer[start:start + rbm.batch_size])\n",
    "\n",
    "        rbm_weights = rbm.components_\n",
    "        rbm_biases = rbm.intercept_hidden_\n",
    "        layer.set_weights([rbm_weights.transpose(), rbm_biases])\n",
    "\n",
    "        iter_l = iter_e + iter_offset\n",
    "\n",
    "        # predict() returns NumPy arrays; calling the model directly returns tf\n",
    "        # tensors, which W then indexes element-wise through TensorFlow (very slow).\n",
    "        x = activation_model.predict(Separability_images)\n",
    "\n",
    "        for layers_in_i in range(len(x)):\n",
    "            (LS_1_squence[layers_in_i, iter_l], LS_2_squence[layers_in_i, iter_l],\n",
    "             J_w_squence[layers_in_i, iter_l], LDA_squence[layers_in_i, iter_l]) = W(\n",
    "                x[layers_in_i], sep_labels)\n",
    "\n",
    "    visual_layer = rbm.transform(visual_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug: shape of the activations last inspected by the separability loop\n",
    "# (displaying the full array would print hundreds of rows).\n",
    "np.shape(x[layers_in_i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug: class balance of the separability subset (digits 0 and 1) instead of\n",
    "# dumping all of its labels.\n",
    "np.unique(Separability_labels, return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Earlier variant of the pre-training loop, kept for reference only (not executed).\n",
    "# It fits each RBM once with n_iter=40 and only afterwards loops over rbm_iters,\n",
    "# so every recorded column would hold the separability of the same final weights.\n",
    "#\n",
    "# visual_layer = train_x\n",
    "# rbm_weights = []\n",
    "# rbm_biases = []\n",
    "# rbm_h_act = []\n",
    "#\n",
    "# for layers_i in range(len(layers)):\n",
    "#     print(\"[DBN] Layer {} Pre-Training\".format(layers_i + 1))\n",
    "#     rbm = BernoulliRBM(n_components=layers[layers_i],\n",
    "#                        n_iter=40,\n",
    "#                        learning_rate=rbm_lr[layers_i],\n",
    "#                        verbose=True,\n",
    "#                        batch_size=32)\n",
    "#     rbm.fit(visual_layer)\n",
    "#\n",
    "#     for iter_e in range(rbm_iters[layers_i]):\n",
    "#         rbm_weights = rbm.components_\n",
    "#         rbm_biases = rbm.intercept_hidden_\n",
    "#\n",
    "#         layer = model.get_layer('rbm_{}'.format(layers_i))\n",
    "#         layer.set_weights(\n",
    "#             [rbm_weights.transpose(), rbm_biases])\n",
    "#\n",
    "#         iter_l = iter_e+sum(rbm_iters[:layers_i])\n",
    "#\n",
    "#         x=activation_model(Separability_images)\n",
    "#\n",
    "#         for layers_in_i in range(len(x)):\n",
    "#             LS_1_squence[layers_in_i,iter_l],LS_2_squence[layers_in_i,iter_l],J_w_squence[layers_in_i,iter_l],\\\n",
    "#             LDA_squence[layers_in_i,iter_l]=W(x[layers_in_i],Separability_labels.reshape(-1,1))\n",
    "#\n",
    "#     visual_layer = rbm.transform(visual_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: separability of the raw pixels, broadcast across every epoch column.\n",
    "# Separability_images is already a NumPy array; passing it directly (instead of a\n",
    "# tf.constant) avoids slow element-wise tf.Tensor indexing inside W.\n",
    "(LS_1_squence_base[0, :], LS_2_squence_base[0, :],\n",
    " J_w_squence_base[0, :], LDA_squence_base[0, :]) = W(\n",
    "    Separability_images, Separability_labels.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "# Supervised fine-tuning of the whole network: one record column per epoch,\n",
    "# placed right after the RBM pre-training columns.\n",
    "pretrain_columns = sum(rbm_iters)\n",
    "finetune_epochs = num_epochs - pretrain_columns  # columns left after pre-training\n",
    "test_y_onehot = to_categorical(test_y, num_classes=outputs)\n",
    "sep_labels = Separability_labels.reshape(-1, 1)  # loop-invariant\n",
    "\n",
    "for iter_e in range(finetune_epochs):\n",
    "    history = model.fit(train_x, train_y, epochs=1, batch_size=128)\n",
    "    iter_l = iter_e + pretrain_columns\n",
    "\n",
    "    # Record the curves consumed by plot_net_figure (previously left at zero).\n",
    "    # NOTE: the accuracy key may be 'acc' instead of 'accuracy' in older tf.keras.\n",
    "    hist = history.history\n",
    "    train_loss_squence[iter_l] = hist['loss'][-1]\n",
    "    train_accuracy_squence[iter_l] = hist.get('accuracy', hist.get('acc'))[-1]\n",
    "    test_loss_squence[iter_l], test_accuracy_squence[iter_l] = model.evaluate(\n",
    "        test_x, test_y_onehot, verbose=0)\n",
    "\n",
    "    # NumPy activations keep W on its fast path (see the pre-training cell).\n",
    "    x = activation_model.predict(Separability_images)\n",
    "    for layers_in_i in range(len(x)):\n",
    "        (LS_1_squence[layers_in_i, iter_l], LS_2_squence[layers_in_i, iter_l],\n",
    "         J_w_squence[layers_in_i, iter_l], LDA_squence[layers_in_i, iter_l]) = W(\n",
    "            x[layers_in_i], sep_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of the tracked layers (same 'input' filter as activation_model).\n",
    "layer_name_list = [lyr.name for lyr in model.layers\n",
    "                   if 'input' not in lyr.name]\n",
    "\n",
    "Separability_figure = plot_Separability_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    LS_1_squence, LS_2_squence, J_w_squence, LDA_squence,\n",
    "    LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base)\n",
    "net_figure = plot_net_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    train_loss_squence, train_accuracy_squence,\n",
    "    test_loss_squence, test_accuracy_squence)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
