{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ae78078e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-08T06:53:19.175652Z",
     "start_time": "2023-05-08T06:53:15.740076Z"
    }
   },
   "outputs": [],
   "source": [
    "# DBN (stacked BernoulliRBM) pre-training on MNIST with layer-wise linear-separability tracking\n",
    "import numpy as np\n",
    "\n",
    "# Imports and notebook setup\n",
    "import numpy as np\n",
    "import os\n",
    "from tensorflow.keras.datasets import cifar10, cifar100\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Conv2D\n",
    "from tensorflow.keras.layers import Dropout\n",
    "from tensorflow.keras.layers import Flatten\n",
    "from tensorflow.keras.constraints import max_norm\n",
    "from tensorflow.keras.optimizers import SGD, Adam\n",
    "from tensorflow.keras.layers import Convolution2D\n",
    "from tensorflow.keras.layers import MaxPooling2D\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler\n",
    "from tensorflow.keras.callbacks import ReduceLROnPlateau\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from tensorflow.keras import backend as K\n",
    "import tensorflow as tf\n",
    "from data_utils import *\n",
    "\n",
    "#from tensorflow.keras.utils import np_utils\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "%matplotlib inline\n",
    "\n",
    "import skimage\n",
    "from skimage.util import img_as_ubyte\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "44f80b4b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-08T06:53:20.371735Z",
     "start_time": "2023-05-08T06:53:19.203419Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load MNIST, flatten every image to a vector and scale pixels to [0, 1].\n",
    "mnist = tf.keras.datasets.mnist\n",
    "(x_train_o, y_train_o), (x_test_o, y_test_o) = mnist.load_data()\n",
    "\n",
    "# Cast to float32 BEFORE dividing so no float64 copy of the dataset is created.\n",
    "x_train_o = x_train_o.reshape(x_train_o.shape[0], -1).astype('float32') / 255.0\n",
    "x_test_o = x_test_o.reshape(x_test_o.shape[0], -1).astype('float32') / 255.0\n",
    "\n",
    "# Labels are kept as integer class ids (not one-hot).\n",
    "train_x, train_y = x_train_o.copy(), y_train_o.copy()\n",
    "test_x, test_y = x_test_o.copy(), y_test_o.copy()\n",
    "\n",
    "# Optional 0-vs-1 subset. The parentheses matter: `|` binds tighter than `==`.\n",
    "# all_index_train = np.where((y_train_o == 0) | (y_train_o == 1))[0]\n",
    "# all_index_test = np.where((y_test_o == 0) | (y_test_o == 1))[0]\n",
    "# train_x, train_y = x_train_o[all_index_train, :], y_train_o[all_index_train]\n",
    "# test_x, test_y = x_test_o[all_index_test, :], y_test_o[all_index_test]\n",
    "\n",
    "# Fixed probe set on which layer-wise linear separability is measured.\n",
    "# Sample WITHOUT replacement: randint could pick the same image several times.\n",
    "n_separability_samples = 2000\n",
    "np.random.seed(1234)\n",
    "LS_index = np.random.choice(len(train_x), size=n_separability_samples, replace=False)\n",
    "\n",
    "Separability_images, Separability_labels = train_x[LS_index, :], train_y[LS_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9fff6df5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-08T06:53:20.387507Z",
     "start_time": "2023-05-08T06:53:20.378908Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hidden units of each stacked RBM / Dense layer (4 layers of 200 units).\n",
    "n_hidden_layers = 4\n",
    "layers = [200] * n_hidden_layers\n",
    "# Number of classes produced by the softmax head.\n",
    "outputs = 10\n",
    "# Per-layer number of RBM pre-training iterations and RBM learning rates.\n",
    "rbm_iters = [40] * n_hidden_layers\n",
    "rbm_lr = [0.01] * n_hidden_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f5b6969d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-08T06:53:20.818705Z",
     "start_time": "2023-05-08T06:53:20.394239Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import linear_model, datasets\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.neural_network import BernoulliRBM\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Activation\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "426bbc10",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-08T06:53:22.919519Z",
     "start_time": "2023-05-08T06:53:20.822156Z"
    }
   },
   "outputs": [],
   "source": [
    "# MLP whose hidden Dense layers are named 'rbm_0', 'rbm_1', ... so the pre-training\n",
    "# cell can copy RBM weights into them; a softmax layer produces `outputs` classes.\n",
    "model = Sequential()\n",
    "for i, units in enumerate(layers):\n",
    "    # Only the first layer needs the input width; later layers infer it.\n",
    "    first_layer_kwargs = {'input_dim': train_x.shape[1]} if i == 0 else {}\n",
    "    model.add(Dense(units,\n",
    "                    activation='relu',\n",
    "                    name='rbm_{}'.format(i),\n",
    "                    **first_layer_kwargs))\n",
    "\n",
    "model.add(Dense(outputs, activation='softmax'))\n",
    "# train_y / test_y are integer class ids (not one-hot), so the sparse variant of\n",
    "# categorical cross-entropy is the matching loss.\n",
    "model.compile(optimizer='Adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "953bdb6c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-08T06:53:22.933687Z",
     "start_time": "2023-05-08T06:53:22.922912Z"
    }
   },
   "outputs": [],
   "source": [
    "# Multi-output view of `model`: one output per layer whose name does not contain\n",
    "# 'input'. It is built from the same layer objects as `model`, so its activations\n",
    "# follow any later set_weights on those layers.\n",
    "layer_outputs = [lyr.output for lyr in model.layers if 'input' not in lyr.name]\n",
    "activation_model = tf.keras.models.Model(\n",
    "    inputs=model.input,\n",
    "    outputs=layer_outputs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ae31fd84",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-08T06:53:23.688692Z",
     "start_time": "2023-05-08T06:53:22.936792Z"
    }
   },
   "outputs": [],
   "source": [
    "from Linear_Separability_numpy import *\n",
    "\n",
    "# Timeline length: one slot per RBM pre-training iteration plus 20 extra slots.\n",
    "num_epochs = sum(rbm_iters) + 20\n",
    "batch_ep = 1\n",
    "x_plot = batch_ep * np.arange(num_epochs)\n",
    "reserved_layers = 0\n",
    "\n",
    "# Activations of every layer exposed by activation_model on the probe set;\n",
    "# len(x) is therefore the number of tracked layers.\n",
    "x = activation_model.predict(Separability_images)\n",
    "n_tracked_layers = len(x) - reserved_layers\n",
    "\n",
    "# Zero-initialised record buffers, one column per timeline slot:\n",
    "# per-layer separability measures ...\n",
    "LS_1_squence, LS_2_squence, J_w_squence, LDA_squence = (\n",
    "    np.zeros((n_tracked_layers, num_epochs)) for _ in range(4))\n",
    "# ... the same four measures with a single row (the *_base variants) ...\n",
    "LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base = (\n",
    "    np.zeros((1, num_epochs)) for _ in range(4))\n",
    "# ... and train/test loss and accuracy histories.\n",
    "train_loss_squence, train_accuracy_squence, test_loss_squence, test_accuracy_squence = (\n",
    "    np.zeros((num_epochs,)) for _ in range(4))\n",
    "\n",
    "# Separability_labels are integer ids already; one-hot labels would first need\n",
    "# Separability_labels = np.argmax(Separability_labels, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5792a5e3",
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-08T06:53:16.329Z"
    },
    "code_folding": [],
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[DBN] Layer 1 Pre-Training\n",
      "[DBN] Layer 2 Pre-Training\n",
      "[DBN] Layer 3 Pre-Training\n",
      "[DBN] Layer 4 Pre-Training\n"
     ]
    }
   ],
   "source": [
    "# Greedy layer-wise DBN pre-training: RBM i is trained on the hidden probabilities\n",
    "# of RBM i-1 (RBM 0 on the pixels). After every epoch its weights are copied into\n",
    "# the Keras layer 'rbm_i' and the separability of every tracked layer is recorded.\n",
    "visual_layer = train_x\n",
    "rbm_weights = []\n",
    "rbm_biases = []\n",
    "rbm_h_act = []\n",
    "rbm_batch_size = 32\n",
    "# Probe labels as a column vector, as passed to W; computed once, not per iteration.\n",
    "probe_labels = Separability_labels.reshape(-1, 1)\n",
    "\n",
    "for layers_i in range(len(layers)):\n",
    "    print(\"[DBN] Layer {} Pre-Training\".format(layers_i + 1))\n",
    "\n",
    "    # Build the RBM ONCE per layer so successive partial_fit calls keep refining the\n",
    "    # same weights; a fresh RBM per iteration would restart from random weights.\n",
    "    rbm = BernoulliRBM(n_components=layers[layers_i],\n",
    "                       learning_rate=rbm_lr[layers_i],\n",
    "                       batch_size=rbm_batch_size)\n",
    "    epoch_offset = sum(rbm_iters[:layers_i])\n",
    "\n",
    "    for iter_e in range(rbm_iters[layers_i]):\n",
    "        # partial_fit takes a single gradient step on the whole array it receives\n",
    "        # (it does not split by batch_size and ignores n_iter; only fit does), so\n",
    "        # feed mini-batches explicitly to run one full epoch per iteration.\n",
    "        for start in range(0, len(visual_layer), rbm_batch_size):\n",
    "            rbm.partial_fit(visual_layer[start:start + rbm_batch_size])\n",
    "\n",
    "        rbm_weights = rbm.components_\n",
    "        rbm_biases = rbm.intercept_hidden_\n",
    "\n",
    "        layer = model.get_layer('rbm_{}'.format(layers_i))\n",
    "        layer.set_weights([rbm_weights.transpose(), rbm_biases])\n",
    "\n",
    "        # Recompute the probe activations after every weight update: activation_model\n",
    "        # shares its layers with `model`, while activations computed before the loop\n",
    "        # would only ever describe the untrained network.\n",
    "        x = activation_model.predict(Separability_images, verbose=0)\n",
    "\n",
    "        iter_l = epoch_offset + iter_e\n",
    "        # Never write past the rows allocated for the record buffers.\n",
    "        for layers_in_i in range(min(len(x), LS_1_squence.shape[0])):\n",
    "            (LS_1_squence[layers_in_i, iter_l], LS_2_squence[layers_in_i, iter_l],\n",
    "             J_w_squence[layers_in_i, iter_l], LDA_squence[layers_in_i, iter_l]) = W(\n",
    "                x[layers_in_i], probe_labels)\n",
    "\n",
    "    # NOTE(review): the next RBM sees sigmoid hidden probabilities (rbm.transform),\n",
    "    # whereas the Keras layers apply ReLU -- confirm this mismatch is intended.\n",
    "    visual_layer = rbm.transform(visual_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73aa49c1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d74df08",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a761236",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
