{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import linear_model, datasets\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.neural_network import BernoulliRBM\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Activation\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "\n",
    "\n",
    "class DBN():\n",
    "\n",
    "    def __init__(self,\n",
    "                 train_data,\n",
    "                 targets,\n",
    "                 layers,\n",
    "                 outputs,\n",
    "                 rbm_lr,\n",
    "                 rbm_iters,\n",
    "                 rbm_dir=None,\n",
    "                 test_data=None,\n",
    "                 test_targets=None,\n",
    "                 epochs=25,\n",
    "                 fine_tune_batch_size=32,\n",
    "                 outdir=\"tmp/\",\n",
    "                 logdir=\"logs/\"):\n",
    "\n",
    "        self.hidden_sizes = layers\n",
    "        self.outputs = outputs\n",
    "        self.targets = targets\n",
    "        self.data = train_data\n",
    "\n",
    "        if test_data is None:\n",
    "            self.validate = False\n",
    "        else:\n",
    "            self.validate = True\n",
    "\n",
    "        self.valid_data = test_data\n",
    "        self.valid_labels = test_targets\n",
    "\n",
    "        self.rbm_learning_rate = rbm_lr\n",
    "        self.rbm_iters = rbm_iters\n",
    "\n",
    "        self.epochs = epochs\n",
    "        self.nn_batch_size = fine_tune_batch_size\n",
    "\n",
    "        self.rbm_weights = []\n",
    "        self.rbm_biases = []\n",
    "        self.rbm_h_act = []\n",
    "\n",
    "        self.model = None\n",
    "        self.history = None\n",
    "\n",
    "        if not os.path.exists(outdir):\n",
    "            os.makedirs(outdir)\n",
    "        if not os.path.exists(logdir):\n",
    "            os.makedirs(logdir)\n",
    "\n",
    "        if outdir[-1] != '/':\n",
    "            outdir = outdir + '/'\n",
    "\n",
    "        self.outdir = outdir\n",
    "        self.logdir = logdir\n",
    "\n",
    "    def pretrain(self, save=True):\n",
    "\n",
    "        visual_layer = self.data\n",
    "\n",
    "        for i in range(len(self.hidden_sizes)):\n",
    "            print(\"[DBN] Layer {} Pre-Training\".format(i + 1))\n",
    "\n",
    "            rbm = BernoulliRBM(n_components=self.hidden_sizes[i],\n",
    "                               n_iter=self.rbm_iters[i],\n",
    "                               learning_rate=self.rbm_learning_rate[i],\n",
    "                               verbose=True,\n",
    "                               batch_size=32)\n",
    "            rbm.fit(visual_layer)\n",
    "            self.rbm_weights.append(rbm.components_)\n",
    "            self.rbm_biases.append(rbm.intercept_hidden_)\n",
    "            self.rbm_h_act.append(rbm.transform(visual_layer))\n",
    "\n",
    "            visual_layer = self.rbm_h_act[-1]\n",
    "\n",
    "        if save:\n",
    "            with open(self.outdir + \"rbm_weights.p\", 'wb') as f:\n",
    "                pickle.dump(self.rbm_weights, f)\n",
    "\n",
    "            with open(self.outdir + \"rbm_biases.p\", 'wb') as f:\n",
    "                pickle.dump(self.rbm_biases, f)\n",
    "\n",
    "            with open(self.outdir + \"rbm_hidden.p\", 'wb') as f:\n",
    "                pickle.dump(self.rbm_h_act, f)\n",
    "\n",
    "    def finetune(self):\n",
    "        model = Sequential()\n",
    "        for i in range(len(self.hidden_sizes)):\n",
    "\n",
    "            if i == 0:\n",
    "                model.add(\n",
    "                    Dense(self.hidden_sizes[i],\n",
    "                          activation='relu',\n",
    "                          input_dim=self.data.shape[1],\n",
    "                          name='rbm_{}'.format(i)))\n",
    "            else:\n",
    "                model.add(\n",
    "                    Dense(self.hidden_sizes[i],\n",
    "                          activation='relu',\n",
    "                          name='rbm_{}'.format(i)))\n",
    "\n",
    "        model.add(Dense(self.outputs, activation='softmax'))\n",
    "        model.compile(optimizer='Adam',\n",
    "                      loss='categorical_crossentropy',\n",
    "                      metrics=['accuracy'])\n",
    "\n",
    "        for i in range(len(self.hidden_sizes)):\n",
    "            layer = model.get_layer('rbm_{}'.format(i))\n",
    "            layer.set_weights(\n",
    "                [self.rbm_weights[i].transpose(), self.rbm_biases[i]])\n",
    "\n",
    "        checkpointer = ModelCheckpoint(filepath=self.outdir +\n",
    "                                       \"dbn_weights.hdf5\",\n",
    "                                       verbose=1,\n",
    "                                       save_best_only=True)\n",
    "        tensorboard = TensorBoard(log_dir=self.logdir)\n",
    "\n",
    "        if self.validate:\n",
    "            self.history = model.fit(trainx,\n",
    "                                     trainy,\n",
    "                                     epochs=self.epochs,\n",
    "                                     batch_size=self.nn_batch_size,\n",
    "                                     validation_data=(self.valid_data,\n",
    "                                                      self.valid_labels),\n",
    "                                     callbacks=[checkpointer, tensorboard])\n",
    "        else:\n",
    "            self.history = model.fit(trainx,\n",
    "                                     trainy,\n",
    "                                     epochs=self.epochs,\n",
    "                                     batch_size=self.nn_batch_size,\n",
    "                                     callbacks=[checkpointer, tensorboard])\n",
    "        self.model = model\n",
    "\n",
    "    def report(self, data, labels):\n",
    "        print(\n",
    "            classification_report(np.argmax(labels, axis=1),\n",
    "                                  np.argmax(self.model.predict(data), axis=1)))\n",
    "\n",
    "    def save_model(self, filename):\n",
    "\n",
    "        if self.model is None:\n",
    "            raise ValueError(\"Run finetune() first\")\n",
    "\n",
    "        with open(self.outdir + filename, mode='w',\n",
    "                  encoding='utf-8') as outfile:\n",
    "\n",
    "            data = {\n",
    "                \"model_config\": self.model.get_config(),\n",
    "                \"loss_acc\": self.history.history\n",
    "            }\n",
    "            json.dump(data, outfile, indent=2)\n",
    "\n",
    "    def load_rbm(self):\n",
    "        try:\n",
    "            self.rbm_weights = pickle.load(self.rbm_dir + \"rbm_weights.p\")\n",
    "            self.rbm_biases = pickle.load(self.rbm_dir + \"rbm_biases.p\")\n",
    "            self.rbm_h_act = pickle.load(self.rbm_dir + \"rbm_hidden.p\")\n",
    "        except:\n",
    "            print(\"No such file or directory.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[DBN] Layer 1 Pre-Training\n",
      "[BernoulliRBM] Iteration 1, pseudo-likelihood = -126.11, time = 10.29s\n",
      "[BernoulliRBM] Iteration 2, pseudo-likelihood = -107.97, time = 11.63s\n",
      "[BernoulliRBM] Iteration 3, pseudo-likelihood = -98.94, time = 11.81s\n",
      "[BernoulliRBM] Iteration 4, pseudo-likelihood = -92.96, time = 11.47s\n",
      "[BernoulliRBM] Iteration 5, pseudo-likelihood = -89.06, time = 11.92s\n",
      "[BernoulliRBM] Iteration 6, pseudo-likelihood = -85.66, time = 11.24s\n",
      "[BernoulliRBM] Iteration 7, pseudo-likelihood = -82.86, time = 11.13s\n",
      "[BernoulliRBM] Iteration 8, pseudo-likelihood = -82.68, time = 11.13s\n",
      "[BernoulliRBM] Iteration 9, pseudo-likelihood = -80.44, time = 10.97s\n",
      "[BernoulliRBM] Iteration 10, pseudo-likelihood = -78.81, time = 10.83s\n",
      "[BernoulliRBM] Iteration 11, pseudo-likelihood = -78.46, time = 11.44s\n",
      "[BernoulliRBM] Iteration 12, pseudo-likelihood = -77.45, time = 10.79s\n",
      "[BernoulliRBM] Iteration 13, pseudo-likelihood = -75.97, time = 10.57s\n",
      "[BernoulliRBM] Iteration 14, pseudo-likelihood = -74.79, time = 10.83s\n",
      "[BernoulliRBM] Iteration 15, pseudo-likelihood = -74.43, time = 11.04s\n",
      "[BernoulliRBM] Iteration 16, pseudo-likelihood = -74.08, time = 10.90s\n",
      "[BernoulliRBM] Iteration 17, pseudo-likelihood = -71.72, time = 11.27s\n",
      "[BernoulliRBM] Iteration 18, pseudo-likelihood = -71.17, time = 10.71s\n",
      "[BernoulliRBM] Iteration 19, pseudo-likelihood = -70.72, time = 10.85s\n",
      "[BernoulliRBM] Iteration 20, pseudo-likelihood = -72.38, time = 10.42s\n",
      "[BernoulliRBM] Iteration 21, pseudo-likelihood = -70.26, time = 10.44s\n",
      "[BernoulliRBM] Iteration 22, pseudo-likelihood = -70.17, time = 10.83s\n",
      "[BernoulliRBM] Iteration 23, pseudo-likelihood = -70.10, time = 10.87s\n",
      "[BernoulliRBM] Iteration 24, pseudo-likelihood = -69.71, time = 10.38s\n",
      "[BernoulliRBM] Iteration 25, pseudo-likelihood = -69.50, time = 10.50s\n",
      "[BernoulliRBM] Iteration 26, pseudo-likelihood = -69.21, time = 10.67s\n",
      "[BernoulliRBM] Iteration 27, pseudo-likelihood = -67.53, time = 10.74s\n",
      "[BernoulliRBM] Iteration 28, pseudo-likelihood = -70.48, time = 10.42s\n",
      "[BernoulliRBM] Iteration 29, pseudo-likelihood = -68.59, time = 10.88s\n",
      "[BernoulliRBM] Iteration 30, pseudo-likelihood = -69.00, time = 10.61s\n",
      "[BernoulliRBM] Iteration 31, pseudo-likelihood = -69.84, time = 10.69s\n",
      "[BernoulliRBM] Iteration 32, pseudo-likelihood = -67.21, time = 10.53s\n",
      "[BernoulliRBM] Iteration 33, pseudo-likelihood = -69.44, time = 10.52s\n",
      "[BernoulliRBM] Iteration 34, pseudo-likelihood = -68.89, time = 10.48s\n",
      "[BernoulliRBM] Iteration 35, pseudo-likelihood = -67.80, time = 10.69s\n",
      "[BernoulliRBM] Iteration 36, pseudo-likelihood = -67.63, time = 10.60s\n",
      "[BernoulliRBM] Iteration 37, pseudo-likelihood = -68.03, time = 10.89s\n",
      "[BernoulliRBM] Iteration 38, pseudo-likelihood = -67.17, time = 11.10s\n",
      "[BernoulliRBM] Iteration 39, pseudo-likelihood = -67.40, time = 11.04s\n",
      "[BernoulliRBM] Iteration 40, pseudo-likelihood = -65.64, time = 11.29s\n",
      "[DBN] Layer 2 Pre-Training\n",
      "[BernoulliRBM] Iteration 1, pseudo-likelihood = -78.02, time = 3.40s\n",
      "[BernoulliRBM] Iteration 2, pseudo-likelihood = -66.12, time = 5.21s\n",
      "[BernoulliRBM] Iteration 3, pseudo-likelihood = -60.25, time = 5.48s\n",
      "[BernoulliRBM] Iteration 4, pseudo-likelihood = -57.01, time = 5.23s\n",
      "[BernoulliRBM] Iteration 5, pseudo-likelihood = -54.56, time = 5.02s\n",
      "[BernoulliRBM] Iteration 6, pseudo-likelihood = -52.37, time = 5.08s\n",
      "[BernoulliRBM] Iteration 7, pseudo-likelihood = -50.25, time = 5.11s\n",
      "[BernoulliRBM] Iteration 8, pseudo-likelihood = -49.41, time = 4.97s\n",
      "[BernoulliRBM] Iteration 9, pseudo-likelihood = -47.50, time = 5.15s\n",
      "[BernoulliRBM] Iteration 10, pseudo-likelihood = -47.04, time = 4.78s\n",
      "[BernoulliRBM] Iteration 11, pseudo-likelihood = -46.75, time = 5.04s\n",
      "[BernoulliRBM] Iteration 12, pseudo-likelihood = -45.68, time = 4.83s\n",
      "[BernoulliRBM] Iteration 13, pseudo-likelihood = -44.81, time = 4.98s\n",
      "[BernoulliRBM] Iteration 14, pseudo-likelihood = -44.45, time = 4.59s\n",
      "[BernoulliRBM] Iteration 15, pseudo-likelihood = -45.01, time = 4.93s\n",
      "[BernoulliRBM] Iteration 16, pseudo-likelihood = -44.13, time = 5.02s\n",
      "[BernoulliRBM] Iteration 17, pseudo-likelihood = -43.20, time = 4.84s\n",
      "[BernoulliRBM] Iteration 18, pseudo-likelihood = -42.55, time = 4.54s\n",
      "[BernoulliRBM] Iteration 19, pseudo-likelihood = -42.62, time = 4.67s\n",
      "[BernoulliRBM] Iteration 20, pseudo-likelihood = -42.15, time = 4.54s\n",
      "[BernoulliRBM] Iteration 21, pseudo-likelihood = -42.22, time = 4.60s\n",
      "[BernoulliRBM] Iteration 22, pseudo-likelihood = -41.81, time = 4.71s\n",
      "[BernoulliRBM] Iteration 23, pseudo-likelihood = -41.90, time = 4.73s\n",
      "[BernoulliRBM] Iteration 24, pseudo-likelihood = -41.93, time = 5.07s\n",
      "[BernoulliRBM] Iteration 25, pseudo-likelihood = -41.61, time = 5.05s\n",
      "[BernoulliRBM] Iteration 26, pseudo-likelihood = -41.20, time = 4.61s\n",
      "[BernoulliRBM] Iteration 27, pseudo-likelihood = -40.82, time = 4.50s\n",
      "[BernoulliRBM] Iteration 28, pseudo-likelihood = -40.07, time = 4.50s\n",
      "[BernoulliRBM] Iteration 29, pseudo-likelihood = -40.71, time = 4.46s\n",
      "[BernoulliRBM] Iteration 30, pseudo-likelihood = -40.37, time = 4.54s\n",
      "[BernoulliRBM] Iteration 31, pseudo-likelihood = -41.18, time = 4.42s\n",
      "[BernoulliRBM] Iteration 32, pseudo-likelihood = -40.68, time = 4.79s\n",
      "[BernoulliRBM] Iteration 33, pseudo-likelihood = -40.14, time = 4.46s\n",
      "[BernoulliRBM] Iteration 34, pseudo-likelihood = -40.40, time = 5.00s\n",
      "[BernoulliRBM] Iteration 35, pseudo-likelihood = -39.59, time = 4.96s\n",
      "[BernoulliRBM] Iteration 36, pseudo-likelihood = -40.44, time = 4.96s\n",
      "[BernoulliRBM] Iteration 37, pseudo-likelihood = -40.10, time = 4.52s\n",
      "[BernoulliRBM] Iteration 38, pseudo-likelihood = -39.61, time = 4.97s\n",
      "[BernoulliRBM] Iteration 39, pseudo-likelihood = -39.95, time = 4.71s\n",
      "[BernoulliRBM] Iteration 40, pseudo-likelihood = -39.38, time = 4.48s\n",
      "[DBN] Layer 3 Pre-Training\n",
      "[BernoulliRBM] Iteration 1, pseudo-likelihood = -59.01, time = 3.62s\n",
      "[BernoulliRBM] Iteration 2, pseudo-likelihood = -49.32, time = 4.58s\n",
      "[BernoulliRBM] Iteration 3, pseudo-likelihood = -44.43, time = 4.84s\n",
      "[BernoulliRBM] Iteration 4, pseudo-likelihood = -41.84, time = 4.85s\n",
      "[BernoulliRBM] Iteration 5, pseudo-likelihood = -39.32, time = 4.69s\n",
      "[BernoulliRBM] Iteration 6, pseudo-likelihood = -38.71, time = 4.57s\n",
      "[BernoulliRBM] Iteration 7, pseudo-likelihood = -37.85, time = 4.51s\n",
      "[BernoulliRBM] Iteration 8, pseudo-likelihood = -36.59, time = 4.58s\n",
      "[BernoulliRBM] Iteration 9, pseudo-likelihood = -35.34, time = 5.46s\n",
      "[BernoulliRBM] Iteration 10, pseudo-likelihood = -34.53, time = 4.51s\n",
      "[BernoulliRBM] Iteration 11, pseudo-likelihood = -34.41, time = 4.90s\n",
      "[BernoulliRBM] Iteration 12, pseudo-likelihood = -34.15, time = 4.46s\n",
      "[BernoulliRBM] Iteration 13, pseudo-likelihood = -34.16, time = 4.66s\n",
      "[BernoulliRBM] Iteration 14, pseudo-likelihood = -33.57, time = 4.55s\n",
      "[BernoulliRBM] Iteration 15, pseudo-likelihood = -33.08, time = 4.51s\n",
      "[BernoulliRBM] Iteration 16, pseudo-likelihood = -32.92, time = 4.68s\n",
      "[BernoulliRBM] Iteration 17, pseudo-likelihood = -32.76, time = 4.50s\n",
      "[BernoulliRBM] Iteration 18, pseudo-likelihood = -32.19, time = 5.11s\n",
      "[BernoulliRBM] Iteration 19, pseudo-likelihood = -32.23, time = 5.08s\n",
      "[BernoulliRBM] Iteration 20, pseudo-likelihood = -32.98, time = 4.55s\n",
      "[BernoulliRBM] Iteration 21, pseudo-likelihood = -32.21, time = 4.54s\n",
      "[BernoulliRBM] Iteration 22, pseudo-likelihood = -31.88, time = 4.58s\n",
      "[BernoulliRBM] Iteration 23, pseudo-likelihood = -31.65, time = 4.60s\n",
      "[BernoulliRBM] Iteration 24, pseudo-likelihood = -31.78, time = 4.55s\n",
      "[BernoulliRBM] Iteration 25, pseudo-likelihood = -31.42, time = 4.92s\n",
      "[BernoulliRBM] Iteration 26, pseudo-likelihood = -31.63, time = 4.60s\n",
      "[BernoulliRBM] Iteration 27, pseudo-likelihood = -31.35, time = 4.53s\n",
      "[BernoulliRBM] Iteration 28, pseudo-likelihood = -32.10, time = 4.72s\n",
      "[BernoulliRBM] Iteration 29, pseudo-likelihood = -31.31, time = 4.51s\n",
      "[BernoulliRBM] Iteration 30, pseudo-likelihood = -31.07, time = 4.62s\n",
      "[BernoulliRBM] Iteration 31, pseudo-likelihood = -31.11, time = 4.70s\n",
      "[BernoulliRBM] Iteration 32, pseudo-likelihood = -30.80, time = 5.20s\n",
      "[BernoulliRBM] Iteration 33, pseudo-likelihood = -30.97, time = 5.15s\n",
      "[BernoulliRBM] Iteration 34, pseudo-likelihood = -30.44, time = 5.11s\n",
      "[BernoulliRBM] Iteration 35, pseudo-likelihood = -31.73, time = 4.75s\n",
      "[BernoulliRBM] Iteration 36, pseudo-likelihood = -30.87, time = 4.56s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[BernoulliRBM] Iteration 37, pseudo-likelihood = -31.06, time = 4.49s\n",
      "[BernoulliRBM] Iteration 38, pseudo-likelihood = -31.11, time = 4.58s\n",
      "[BernoulliRBM] Iteration 39, pseudo-likelihood = -30.59, time = 4.90s\n",
      "[BernoulliRBM] Iteration 40, pseudo-likelihood = -30.72, time = 5.08s\n",
      "[DBN] Layer 4 Pre-Training\n",
      "[BernoulliRBM] Iteration 1, pseudo-likelihood = -54.25, time = 3.40s\n",
      "[BernoulliRBM] Iteration 2, pseudo-likelihood = -45.78, time = 4.72s\n",
      "[BernoulliRBM] Iteration 3, pseudo-likelihood = -40.34, time = 4.79s\n",
      "[BernoulliRBM] Iteration 4, pseudo-likelihood = -38.30, time = 4.63s\n",
      "[BernoulliRBM] Iteration 5, pseudo-likelihood = -36.93, time = 4.74s\n",
      "[BernoulliRBM] Iteration 6, pseudo-likelihood = -35.45, time = 5.20s\n",
      "[BernoulliRBM] Iteration 7, pseudo-likelihood = -35.08, time = 5.25s\n",
      "[BernoulliRBM] Iteration 8, pseudo-likelihood = -34.69, time = 4.79s\n",
      "[BernoulliRBM] Iteration 9, pseudo-likelihood = -34.48, time = 5.44s\n",
      "[BernoulliRBM] Iteration 10, pseudo-likelihood = -34.23, time = 5.43s\n",
      "[BernoulliRBM] Iteration 11, pseudo-likelihood = -33.28, time = 5.39s\n",
      "[BernoulliRBM] Iteration 12, pseudo-likelihood = -32.96, time = 4.99s\n",
      "[BernoulliRBM] Iteration 13, pseudo-likelihood = -33.26, time = 4.65s\n",
      "[BernoulliRBM] Iteration 14, pseudo-likelihood = -32.32, time = 4.73s\n",
      "[BernoulliRBM] Iteration 15, pseudo-likelihood = -32.56, time = 5.26s\n",
      "[BernoulliRBM] Iteration 16, pseudo-likelihood = -32.65, time = 4.63s\n",
      "[BernoulliRBM] Iteration 17, pseudo-likelihood = -32.15, time = 5.06s\n",
      "[BernoulliRBM] Iteration 18, pseudo-likelihood = -31.98, time = 5.11s\n",
      "[BernoulliRBM] Iteration 19, pseudo-likelihood = -31.17, time = 5.35s\n",
      "[BernoulliRBM] Iteration 20, pseudo-likelihood = -31.62, time = 5.38s\n",
      "[BernoulliRBM] Iteration 21, pseudo-likelihood = -31.74, time = 4.77s\n",
      "[BernoulliRBM] Iteration 22, pseudo-likelihood = -31.77, time = 4.91s\n",
      "[BernoulliRBM] Iteration 23, pseudo-likelihood = -31.89, time = 5.47s\n",
      "[BernoulliRBM] Iteration 24, pseudo-likelihood = -31.40, time = 4.71s\n",
      "[BernoulliRBM] Iteration 25, pseudo-likelihood = -31.67, time = 4.75s\n",
      "[BernoulliRBM] Iteration 26, pseudo-likelihood = -31.43, time = 4.83s\n",
      "[BernoulliRBM] Iteration 27, pseudo-likelihood = -32.00, time = 4.83s\n",
      "[BernoulliRBM] Iteration 28, pseudo-likelihood = -30.63, time = 4.70s\n",
      "[BernoulliRBM] Iteration 29, pseudo-likelihood = -31.03, time = 4.98s\n",
      "[BernoulliRBM] Iteration 30, pseudo-likelihood = -31.27, time = 4.84s\n",
      "[BernoulliRBM] Iteration 31, pseudo-likelihood = -31.32, time = 4.80s\n",
      "[BernoulliRBM] Iteration 32, pseudo-likelihood = -31.14, time = 5.30s\n",
      "[BernoulliRBM] Iteration 33, pseudo-likelihood = -30.82, time = 4.91s\n",
      "[BernoulliRBM] Iteration 34, pseudo-likelihood = -30.16, time = 4.92s\n",
      "[BernoulliRBM] Iteration 35, pseudo-likelihood = -31.01, time = 4.83s\n",
      "[BernoulliRBM] Iteration 36, pseudo-likelihood = -30.49, time = 4.90s\n",
      "[BernoulliRBM] Iteration 37, pseudo-likelihood = -30.72, time = 4.98s\n",
      "[BernoulliRBM] Iteration 38, pseudo-likelihood = -29.85, time = 4.85s\n",
      "[BernoulliRBM] Iteration 39, pseudo-likelihood = -30.30, time = 4.92s\n",
      "[BernoulliRBM] Iteration 40, pseudo-likelihood = -31.28, time = 4.87s\n",
      "Epoch 1/25\n",
      "1861/1875 [============================>.] - ETA: 0s - loss: 37.0327 - accuracy: 0.8791WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 36.7986 - accuracy: 0.8794\n",
      "Epoch 2/25\n",
      "1871/1875 [============================>.] - ETA: 0s - loss: 3.7199 - accuracy: 0.9435WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 3.7171 - accuracy: 0.9435\n",
      "Epoch 3/25\n",
      "1861/1875 [============================>.] - ETA: 0s - loss: 1.5806 - accuracy: 0.9559WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 1.5795 - accuracy: 0.9558\n",
      "Epoch 4/25\n",
      "1868/1875 [============================>.] - ETA: 0s - loss: 0.8254 - accuracy: 0.9640WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.8251 - accuracy: 0.9640\n",
      "Epoch 5/25\n",
      "1861/1875 [============================>.] - ETA: 0s - loss: 0.4959 - accuracy: 0.9680WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.4961 - accuracy: 0.9680\n",
      "Epoch 6/25\n",
      "1862/1875 [============================>.] - ETA: 0s - loss: 0.3436 - accuracy: 0.9708WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.3433 - accuracy: 0.9708\n",
      "Epoch 7/25\n",
      "1869/1875 [============================>.] - ETA: 0s - loss: 0.2127 - accuracy: 0.9752WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.2131 - accuracy: 0.9751\n",
      "Epoch 8/25\n",
      "1870/1875 [============================>.] - ETA: 0s - loss: 0.1816 - accuracy: 0.9743WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.1814 - accuracy: 0.9743\n",
      "Epoch 9/25\n",
      "1872/1875 [============================>.] - ETA: 0s - loss: 0.1295 - accuracy: 0.9791WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.1296 - accuracy: 0.9791\n",
      "Epoch 10/25\n",
      "1865/1875 [============================>.] - ETA: 0s - loss: 0.1125 - accuracy: 0.9807WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.1127 - accuracy: 0.9807\n",
      "Epoch 11/25\n",
      "1868/1875 [============================>.] - ETA: 0s - loss: 0.0891 - accuracy: 0.9818WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0896 - accuracy: 0.9818\n",
      "Epoch 12/25\n",
      "1872/1875 [============================>.] - ETA: 0s - loss: 0.0680 - accuracy: 0.9849WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0680 - accuracy: 0.9849\n",
      "Epoch 13/25\n",
      "1865/1875 [============================>.] - ETA: 0s - loss: 0.0675 - accuracy: 0.9857WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0678 - accuracy: 0.9857\n",
      "Epoch 14/25\n",
      "1875/1875 [==============================] - ETA: 0s - loss: 0.0558 - accuracy: 0.9873WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0558 - accuracy: 0.9873\n",
      "Epoch 15/25\n",
      "1870/1875 [============================>.] - ETA: 0s - loss: 0.0566 - accuracy: 0.9872WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0569 - accuracy: 0.9872\n",
      "Epoch 16/25\n",
      "1871/1875 [============================>.] - ETA: 0s - loss: 0.0438 - accuracy: 0.9893WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0437 - accuracy: 0.9893\n",
      "Epoch 17/25\n",
      "1874/1875 [============================>.] - ETA: 0s - loss: 0.0468 - accuracy: 0.9896WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0468 - accuracy: 0.9896\n",
      "Epoch 18/25\n",
      "1862/1875 [============================>.] - ETA: 0s - loss: 0.0358 - accuracy: 0.9911WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0361 - accuracy: 0.9911\n",
      "Epoch 19/25\n",
      "1859/1875 [============================>.] - ETA: 0s - loss: 0.0376 - accuracy: 0.9913WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0376 - accuracy: 0.9913\n",
      "Epoch 20/25\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1869/1875 [============================>.] - ETA: 0s - loss: 0.0389 - accuracy: 0.9911WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0390 - accuracy: 0.9911\n",
      "Epoch 21/25\n",
      "1864/1875 [============================>.] - ETA: 0s - loss: 0.0335 - accuracy: 0.9925WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0334 - accuracy: 0.9925\n",
      "Epoch 22/25\n",
      "1857/1875 [============================>.] - ETA: 0s - loss: 0.0306 - accuracy: 0.9931WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 6s 3ms/step - loss: 0.0306 - accuracy: 0.9930\n",
      "Epoch 23/25\n",
      "1872/1875 [============================>.] - ETA: 0s - loss: 0.0298 - accuracy: 0.9933WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0297 - accuracy: 0.9933\n",
      "Epoch 24/25\n",
      "1873/1875 [============================>.] - ETA: 0s - loss: 0.0282 - accuracy: 0.9937WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0282 - accuracy: 0.9937\n",
      "Epoch 25/25\n",
      "1858/1875 [============================>.] - ETA: 0s - loss: 0.0323 - accuracy: 0.9927WARNING:tensorflow:Can save best model only with val_loss available, skipping.\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0321 - accuracy: 0.9927\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "Object of type 'TensorShape' is not JSON serializable",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-9db6ef1a76be>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpretrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinetune\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"mnist_dbn_model.json\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Training Report\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-1-bb19d09f4f86>\u001b[0m in \u001b[0;36msave_model\u001b[0;34m(self, filename)\u001b[0m\n\u001b[1;32m    159\u001b[0m                 \u001b[0;34m\"loss_acc\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    160\u001b[0m             }\n\u001b[0;32m--> 161\u001b[0;31m             \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    163\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mload_rbm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/json/__init__.py\u001b[0m in \u001b[0;36mdump\u001b[0;34m(obj, fp, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)\u001b[0m\n\u001b[1;32m    177\u001b[0m     \u001b[0;31m# could accelerate with writelines in some versions of Python, at\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m     \u001b[0;31m# a debuggability cost\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    180\u001b[0m         \u001b[0mfp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    181\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode\u001b[0;34m(o, _current_indent_level)\u001b[0m\n\u001b[1;32m    428\u001b[0m             \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0m_iterencode_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    429\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 430\u001b[0;31m             \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0m_iterencode_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    431\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    432\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mmarkers\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode_dict\u001b[0;34m(dct, _current_indent_level)\u001b[0m\n\u001b[1;32m    402\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    403\u001b[0m                     \u001b[0mchunks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_iterencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 404\u001b[0;31m                 \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mchunks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    405\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mnewline_indent\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    406\u001b[0m             \u001b[0m_current_indent_level\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode_dict\u001b[0;34m(dct, _current_indent_level)\u001b[0m\n\u001b[1;32m    402\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    403\u001b[0m                     \u001b[0mchunks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_iterencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 404\u001b[0;31m                 \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mchunks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    405\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mnewline_indent\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    406\u001b[0m             \u001b[0m_current_indent_level\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/json/encoder.py\u001b[0m in \u001b[0;36m_iterencode\u001b[0;34m(o, _current_indent_level)\u001b[0m\n\u001b[1;32m    435\u001b[0m                     \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Circular reference detected\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    436\u001b[0m                 \u001b[0mmarkers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmarkerid\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 437\u001b[0;31m             \u001b[0mo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    438\u001b[0m             \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0m_iterencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_current_indent_level\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    439\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mmarkers\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/json/encoder.py\u001b[0m in \u001b[0;36mdefault\u001b[0;34m(self, o)\u001b[0m\n\u001b[1;32m    178\u001b[0m         \"\"\"\n\u001b[1;32m    179\u001b[0m         raise TypeError(\"Object of type '%s' is not JSON serializable\" %\n\u001b[0;32m--> 180\u001b[0;31m                         o.__class__.__name__)\n\u001b[0m\u001b[1;32m    181\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    182\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: Object of type 'TensorShape' is not JSON serializable"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "# Experiment configuration: tune these here rather than inside the DBN call.\n",
    "SEED = 42\n",
    "N_CLASSES = 10  # MNIST digit classes; also the width of the DBN output layer\n",
    "HIDDEN_SIZES = [200, 200, 200, 200]\n",
    "RBM_ITERS = [40] * len(HIDDEN_SIZES)  # one entry per hidden layer\n",
    "RBM_LR = [0.01] * len(HIDDEN_SIZES)  # one entry per hidden layer\n",
    "\n",
    "# Seed numpy's global RNG so numpy-driven randomness is repeatable.\n",
    "# NOTE(review): presumably the RBMs inside DBN draw from this RNG when no\n",
    "# random_state is given -- confirm. TensorFlow weight init is not seeded here.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Load MNIST, flatten each image to a 1-D vector and scale pixels to [0, 1].\n",
    "# Casting to float32 first halves memory versus the float64 that dividing the\n",
    "# raw integer pixels by 255.0 would otherwise produce.\n",
    "(trainx_raw, trainy_raw), (testx_raw, testy_raw) = tf.keras.datasets.mnist.load_data()\n",
    "trainx = trainx_raw.reshape(trainx_raw.shape[0], -1).astype(\"float32\") / 255.0\n",
    "testx = testx_raw.reshape(testx_raw.shape[0], -1).astype(\"float32\") / 255.0\n",
    "\n",
    "# One-hot encode labels; an explicit num_classes keeps both splits N_CLASSES wide.\n",
    "trainy = to_categorical(trainy_raw, num_classes=N_CLASSES)\n",
    "testy = to_categorical(testy_raw, num_classes=N_CLASSES)\n",
    "\n",
    "assert trainx.shape[1] == testx.shape[1], \"train/test feature widths differ\"\n",
    "assert trainy.shape == (trainx.shape[0], N_CLASSES)\n",
    "assert testy.shape == (testx.shape[0], N_CLASSES)\n",
    "\n",
    "# No test_data/test_targets are passed, so DBN sets self.validate = False.\n",
    "dbn = DBN(\n",
    "    train_data=trainx,\n",
    "    targets=trainy,\n",
    "    layers=HIDDEN_SIZES,\n",
    "    outputs=N_CLASSES,\n",
    "    rbm_iters=RBM_ITERS,\n",
    "    rbm_lr=RBM_LR,\n",
    "    outdir=\"mnistrbm/\",\n",
    "    logdir=\"mnistrbm_logs/\")\n",
    "dbn.pretrain(save=True)\n",
    "dbn.finetune()\n",
    "\n",
    "print(\"Training Report\")\n",
    "dbn.report(trainx, trainy)\n",
    "\n",
    "print(\"Testing Report\")\n",
    "dbn.report(testx, testy)\n",
    "\n",
    "# Save last so the reports above are produced even if saving fails.\n",
    "# NOTE(review): this cell's recorded output shows save_model raising\n",
    "# TypeError: Object of type 'TensorShape' is not JSON serializable; the fix\n",
    "# belongs in DBN.save_model (convert any TensorShape to a plain list first).\n",
    "dbn.save_model(\"mnist_dbn_model.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
