{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:21.955414Z",
     "start_time": "2023-05-16T06:48:20.981138Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm import trange\n",
    "\n",
    "class RBM:\n",
    "\t\"\"\"Restricted Boltzmann Machine trained with k-step contrastive divergence.\n",
    "\n",
    "\tParameters are plain tensors on ``self.device``: ``W`` is\n",
    "\t(n_hidden, n_visible), ``vb`` is (1, n_visible) and ``hb`` is (1, n_hidden).\n",
    "\tAfter every training epoch a clone of each one is appended to ``W_list``,\n",
    "\t``vb_list`` and ``hb_list``.\n",
    "\t\"\"\"\n",
    "\n",
    "\tdef __init__(self, n_visible, n_hidden, lr=0.001, epochs=5, mode='bernoulli', batch_size=32, k=3, optimizer='adam', gpu=False, savefile=None, early_stopping_patience=5):\n",
    "\t\tself.mode = mode # 'bernoulli' samples binary units; any other value adds unit Gaussian noise\n",
    "\t\tself.n_hidden = n_hidden #  Number of hidden nodes\n",
    "\t\tself.n_visible = n_visible # Number of visible nodes\n",
    "\t\tself.lr = lr # Learning rate for the CD algorithm\n",
    "\t\tself.epochs = epochs # Maximum number of passes over the data\n",
    "\t\tself.batch_size = batch_size\n",
    "\t\tself.k = k # Number of Gibbs steps per update\n",
    "\t\tself.optimizer = optimizer # 'adam' rescales the raw CD gradients, anything else uses them directly\n",
    "\t\tself.beta_1=0.9\n",
    "\t\tself.beta_2=0.999\n",
    "\t\tself.epsilon=1e-07\n",
    "\t\t# Running first/second moments for W (index 0), vb (1) and hb (2)\n",
    "\t\tself.m = [0, 0, 0]\n",
    "\t\tself.v = [0, 0, 0]\n",
    "\t\t# Not read anywhere in this class; kept so existing attribute access keeps working\n",
    "\t\tself.m_batches = {0:[], 1:[], 2:[]}\n",
    "\t\tself.v_batches = {0:[], 1:[], 2:[]}\n",
    "\t\tself.savefile = savefile\n",
    "\t\tself.early_stopping_patience = early_stopping_patience\n",
    "\t\tself.stagnation = 0\n",
    "\t\tself.previous_loss_before_stagnation = 0\n",
    "\t\tself.progress = [] # Mean reconstruction error per epoch\n",
    "\n",
    "\t\tif torch.cuda.is_available() and gpu==True:\n",
    "\t\t\tdev = \"cuda:0\"\n",
    "\t\telse:\n",
    "\t\t\tdev = \"cpu\"\n",
    "\t\tself.device = torch.device(dev)\n",
    "\n",
    "\t\t# Initialize weights and biases\n",
    "\t\tstd = 4 * np.sqrt(6. / (self.n_visible + self.n_hidden))\n",
    "\t\tself.W = torch.normal(mean=0, std=std, size=(self.n_hidden, self.n_visible))\n",
    "\t\tself.vb = torch.zeros(size=(1, self.n_visible), dtype=torch.float32)\n",
    "\t\tself.hb = torch.zeros(size=(1, self.n_hidden), dtype=torch.float32)\n",
    "\n",
    "\t\tself.W = self.W.to(self.device)\n",
    "\t\tself.vb = self.vb.to(self.device)\n",
    "\t\tself.hb = self.hb.to(self.device)\n",
    "\n",
    "\t\t# Per-epoch parameter snapshots, filled by train()\n",
    "\t\tself.W_list = []\n",
    "\t\tself.vb_list = []\n",
    "\t\tself.hb_list = []\n",
    "\n",
    "\tdef _sample(self, probabilities):\n",
    "\t\t\"\"\"Draw unit states from activation probabilities according to ``self.mode``.\"\"\"\n",
    "\t\tif self.mode == 'bernoulli':\n",
    "\t\t\treturn torch.bernoulli(probabilities)\n",
    "\t\t# randn_like creates the noise on the same device/dtype as the probabilities;\n",
    "\t\t# noise created on the CPU would not add to CUDA activations.\n",
    "\t\treturn probabilities + torch.randn_like(probabilities)\n",
    "\n",
    "\tdef sample_h(self, x):\n",
    "\t\t\"\"\"Return ``(p(h|v), sampled h)`` for a batch of visible rows ``x``.\"\"\"\n",
    "\t\tp_h_given_v = torch.sigmoid(torch.mm(x, self.W.t()) + self.hb)\n",
    "\t\treturn p_h_given_v, self._sample(p_h_given_v)\n",
    "\n",
    "\tdef sample_v(self, y):\n",
    "\t\t\"\"\"Return ``(p(v|h), sampled v)`` for a batch of hidden rows ``y``.\"\"\"\n",
    "\t\tp_v_given_h = torch.sigmoid(torch.mm(y, self.W) + self.vb)\n",
    "\t\treturn p_v_given_h, self._sample(p_v_given_h)\n",
    "\n",
    "\tdef adam(self, g, epoch, index):\n",
    "\t\t\"\"\"Rescale gradient ``g`` using the running moments stored at ``index``.\n",
    "\n",
    "\t\t``epoch`` (1-based) is used as the bias-correction time step.\n",
    "\t\t\"\"\"\n",
    "\t\tself.m[index] = self.beta_1 * self.m[index] + (1 - self.beta_1) * g\n",
    "\t\tself.v[index] = self.beta_2 * self.v[index] + (1 - self.beta_2) * torch.pow(g, 2)\n",
    "\n",
    "\t\t# NOTE(review): m_hat adds an extra (1 - beta_1) * g / (1 - beta_1**epoch) term on top of\n",
    "\t\t# the bias-corrected first moment, and the correction advances per epoch rather than per\n",
    "\t\t# update. Left unchanged here; confirm this is intended.\n",
    "\t\tm_hat = self.m[index] / (1 - np.power(self.beta_1, epoch)) + (1 - self.beta_1) * g / (1 - np.power(self.beta_1, epoch))\n",
    "\t\tv_hat = self.v[index] / (1 - np.power(self.beta_2, epoch))\n",
    "\t\treturn m_hat / (torch.sqrt(v_hat) + self.epsilon)\n",
    "\n",
    "\tdef update(self, v0, vk, ph0, phk, epoch):\n",
    "\t\t\"\"\"Apply one contrastive-divergence step to ``W``, ``vb`` and ``hb`` in place.\n",
    "\n",
    "\t\tThe gradients are sums over the batch (they are not divided by the batch size).\n",
    "\t\t\"\"\"\n",
    "\t\tdW = (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()\n",
    "\t\tdvb = torch.sum((v0 - vk), 0)\n",
    "\t\tdhb = torch.sum((ph0 - phk), 0)\n",
    "\n",
    "\t\tif self.optimizer == 'adam':\n",
    "\t\t\tdW = self.adam(dW, epoch, 0)\n",
    "\t\t\tdvb = self.adam(dvb, epoch, 1)\n",
    "\t\t\tdhb = self.adam(dhb, epoch, 2)\n",
    "\n",
    "\t\tself.W += self.lr * dW\n",
    "\t\tself.vb += self.lr * dvb\n",
    "\t\tself.hb += self.lr * dhb\n",
    "\n",
    "\tdef train(self, dataset):\n",
    "\t\t\"\"\"Train on ``dataset`` (rows are samples) for up to ``self.epochs`` epochs.\n",
    "\n",
    "\t\tOnly full batches are used; a trailing partial batch is skipped. Appends the\n",
    "\t\tmean absolute reconstruction error of each epoch to ``self.progress`` and a\n",
    "\t\tparameter snapshot to the ``*_list`` attributes, may stop early when the\n",
    "\t\terror keeps rising, and saves the parameters if ``self.savefile`` is set.\n",
    "\n",
    "\t\tRaises:\n",
    "\t\t\tValueError: if ``dataset`` has fewer rows than ``self.batch_size``.\n",
    "\t\t\"\"\"\n",
    "\t\tdataset = dataset.to(self.device)\n",
    "\t\tn_samples = dataset.shape[0]\n",
    "\t\tif n_samples < self.batch_size:\n",
    "\t\t\traise ValueError('dataset has %d rows but batch_size is %d; at least one full batch is required' % (n_samples, self.batch_size))\n",
    "\t\t# The +1 keeps the last *full* batch when n_samples is an exact multiple of\n",
    "\t\t# batch_size (it used to be skipped).\n",
    "\t\tbatch_starts = range(0, n_samples - self.batch_size + 1, self.batch_size)\n",
    "\n",
    "\t\tlearning = trange(self.epochs, desc=str('Starting...'))\n",
    "\t\tfor epoch in learning:\n",
    "\t\t\ttrain_loss = 0\n",
    "\t\t\tcounter = 0\n",
    "\t\t\tfor batch_start_index in batch_starts:\n",
    "\t\t\t\tv0 = dataset[batch_start_index:batch_start_index+self.batch_size]\n",
    "\t\t\t\tvk = v0\n",
    "\t\t\t\tph0, _ = self.sample_h(v0)\n",
    "\n",
    "\t\t\t\t# k alternating Gibbs steps starting from the data\n",
    "\t\t\t\tfor k in range(self.k):\n",
    "\t\t\t\t\t_, hk = self.sample_h(vk)\n",
    "\t\t\t\t\t_, vk = self.sample_v(hk)\n",
    "\t\t\t\tphk, _ = self.sample_h(vk)\n",
    "\t\t\t\tself.update(v0, vk, ph0, phk, epoch+1)\n",
    "\t\t\t\ttrain_loss += torch.mean(torch.abs(v0-vk))\n",
    "\t\t\t\tcounter += 1\n",
    "\n",
    "\t\t\tepoch_loss = train_loss.item()/counter\n",
    "\t\t\tself.progress.append(epoch_loss)\n",
    "\t\t\tdetails = {'epoch': epoch+1, 'loss': round(epoch_loss, 4)}\n",
    "\t\t\tlearning.set_description(str(details))\n",
    "\t\t\tlearning.refresh()\n",
    "\n",
    "\t\t\t# clone() so later in-place updates do not alter the stored snapshot\n",
    "\t\t\tself.W_list.append(self.W.clone())\n",
    "\t\t\tself.vb_list.append(self.vb.clone())\n",
    "\t\t\tself.hb_list.append(self.hb.clone())\n",
    "\n",
    "\t\t\t# Early stopping: count consecutive epochs (after a warm-up) whose loss is above\n",
    "\t\t\t# the last recorded loss; any other epoch resets the counter and the reference.\n",
    "\t\t\tif epoch_loss > self.previous_loss_before_stagnation and epoch>self.early_stopping_patience+1:\n",
    "\t\t\t\tself.stagnation += 1\n",
    "\t\t\t\tif self.stagnation == self.early_stopping_patience-1:\n",
    "\t\t\t\t\tlearning.close()\n",
    "\t\t\t\t\tprint(\"Not Improving the stopping training loop.\")\n",
    "\t\t\t\t\tbreak\n",
    "\t\t\telse:\n",
    "\t\t\t\tself.previous_loss_before_stagnation = epoch_loss\n",
    "\t\t\t\tself.stagnation = 0\n",
    "\t\tlearning.close()\n",
    "\t\tif self.savefile is not None:\n",
    "\t\t\tmodel = {'W':self.W, 'vb':self.vb, 'hb':self.hb}\n",
    "\t\t\ttorch.save(model, self.savefile)\n",
    "\n",
    "\tdef load_rbm(self, savefile):\n",
    "\t\t\"\"\"Load ``W``, ``vb`` and ``hb`` from a file written by train() onto ``self.device``.\"\"\"\n",
    "\t\t# map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.\n",
    "\t\t# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "\t\tloaded = torch.load(savefile, map_location=self.device)\n",
    "\t\tself.W = loaded['W'].to(self.device)\n",
    "\t\tself.vb = loaded['vb'].to(self.device)\n",
    "\t\tself.hb = loaded['hb'].to(self.device)\n",
    "\n",
    "\n",
    "\n",
    "def trial_dataset():\n",
    "\t\"\"\"Build a shuffled toy set of 2000 ten-bit float32 rows as a torch tensor.\n",
    "\n",
    "\tThe first 1000 rows generated are mostly ones and the next 1000 mostly zeros:\n",
    "\teach bit takes the minority value when random.random() > 0.75. The rows are\n",
    "\tthen shuffled in place with np.random.shuffle.\n",
    "\t\"\"\"\n",
    "\tdef draw_rows(rare_bit, common_bit):\n",
    "\t\treturn [[rare_bit if random.random()>0.75 else common_bit for _ in range(10)] for _ in range(1000)]\n",
    "\n",
    "\trows = draw_rows(0, 1) + draw_rows(1, 0)\n",
    "\tsamples = np.array(rows, dtype=np.float32)\n",
    "\tnp.random.shuffle(samples)\n",
    "\treturn torch.from_numpy(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:22.003572Z",
     "start_time": "2023-05-16T06:48:21.961281Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "from tqdm import trange\n",
    "#from RBM import RBM\n",
    "\n",
    "class DBN:\n",
    "\t\"\"\"Deep Belief Network assembled from greedily trained RBM layers.\n",
    "\n",
    "\t``layer_parameters[i]`` holds the ``W`` / ``hb`` / ``vb`` tensors of layer i and\n",
    "\t``layer_parameters_list[i]`` the per-epoch snapshots recorded by that layer's RBM.\n",
    "\t\"\"\"\n",
    "\tdef __init__(self, input_size, layers, mode='bernoulli', gpu=False, k=5, savefile=None):\n",
    "\t\t# NOTE(review): `gpu` is accepted but not stored; train_DBN builds its RBMs with gpu=True.\n",
    "\t\tself.layers = layers\n",
    "\t\tself.input_size = input_size\n",
    "\t\tself.layer_parameters = [{'W':None, 'hb':None, 'vb':None} for _ in range(len(layers))]\n",
    "\t\tself.layer_parameters_list = [{'W_list':None, 'hb_list':None, 'vb_list':None} for _ in range(len(layers))]\n",
    "\t\tself.k = k # Number of stochastic passes that are averaged together\n",
    "\t\tself.mode = mode\n",
    "\t\tself.savefile = savefile\n",
    "\n",
    "\tdef _sample(self, probabilities):\n",
    "\t\t\"\"\"Draw unit states from activation probabilities according to ``self.mode``.\"\"\"\n",
    "\t\tif self.mode == 'bernoulli':\n",
    "\t\t\treturn torch.bernoulli(probabilities)\n",
    "\t\t# Noise is created on the same device/dtype as the probabilities.\n",
    "\t\treturn probabilities + torch.randn_like(probabilities)\n",
    "\n",
    "\tdef sample_v(self, y, W, vb):\n",
    "\t\t\"\"\"Return ``(p(v|h), sampled v)`` for hidden rows ``y`` under weights ``W`` and bias ``vb``.\"\"\"\n",
    "\t\tp_v_given_h = torch.sigmoid(torch.mm(y, W) + vb)\n",
    "\t\treturn p_v_given_h, self._sample(p_v_given_h)\n",
    "\n",
    "\tdef sample_h(self, x, W, hb):\n",
    "\t\t\"\"\"Return ``(p(h|v), sampled h)`` for visible rows ``x`` under weights ``W`` and bias ``hb``.\"\"\"\n",
    "\t\tp_h_given_v = torch.sigmoid(torch.mm(x, W.t()) + hb)\n",
    "\t\treturn p_h_given_v, self._sample(p_h_given_v)\n",
    "\n",
    "\tdef generate_input_for_layer(self, index, x):\n",
    "\t\t\"\"\"Propagate ``x`` through layers ``0..index-1`` to build the input of layer ``index``.\n",
    "\n",
    "\t\tFor ``index > 0`` the result is the mean of ``self.k`` independent sampled\n",
    "\t\tpasses; for ``index == 0`` it is a clone of ``x``.\n",
    "\t\t\"\"\"\n",
    "\t\tif index>0:\n",
    "\t\t\tx_gen = []\n",
    "\t\t\tfor _ in range(self.k):\n",
    "\t\t\t\tx_dash = x.clone()\n",
    "\t\t\t\tfor i in range(index):\n",
    "\t\t\t\t\t_, x_dash = self.sample_h(x_dash, self.layer_parameters[i]['W'], self.layer_parameters[i]['hb'])\n",
    "\t\t\t\tx_gen.append(x_dash)\n",
    "\n",
    "\t\t\tx_dash = torch.stack(x_gen)\n",
    "\t\t\tx_dash = torch.mean(x_dash, dim=0)\n",
    "\t\telse:\n",
    "\t\t\tx_dash = x.clone()\n",
    "\t\treturn x_dash\n",
    "\n",
    "\tdef train_DBN(self, x):\n",
    "\t\t\"\"\"Train one RBM per entry of ``self.layers``, each on the output of the layers below.\n",
    "\n",
    "\t\tStores the final parameters (moved to the CPU) and the per-epoch snapshot\n",
    "\t\tlists of every RBM, then saves ``layer_parameters`` if ``self.savefile`` is set.\n",
    "\t\t\"\"\"\n",
    "\t\tfor index, layer in enumerate(self.layers):\n",
    "\t\t\tif index==0:\n",
    "\t\t\t\tvn = self.input_size\n",
    "\t\t\telse:\n",
    "\t\t\t\tvn = self.layers[index-1]\n",
    "\t\t\thn = self.layers[index]\n",
    "\n",
    "\t\t\trbm = RBM(vn, hn, epochs=100, mode='bernoulli', lr=0.0005, k=10, batch_size=128, gpu=True, optimizer='adam', early_stopping_patience=10)\n",
    "\t\t\tx_dash = self.generate_input_for_layer(index, x)\n",
    "\t\t\trbm.train(x_dash)\n",
    "\t\t\tself.layer_parameters[index]['W'] = rbm.W.cpu()\n",
    "\t\t\tself.layer_parameters[index]['hb'] = rbm.hb.cpu()\n",
    "\t\t\tself.layer_parameters[index]['vb'] = rbm.vb.cpu()\n",
    "\n",
    "\t\t\tself.layer_parameters_list[index]['W_list'] = rbm.W_list\n",
    "\t\t\tself.layer_parameters_list[index]['hb_list'] = rbm.hb_list\n",
    "\t\t\tself.layer_parameters_list[index]['vb_list'] = rbm.vb_list\n",
    "\n",
    "\t\t\tprint(\"Finished Training Layer:\", index, \"to\", index+1)\n",
    "\t\tif self.savefile is not None:\n",
    "\t\t\ttorch.save(self.layer_parameters, self.savefile)\n",
    "\n",
    "\tdef reconstructor(self, x):\n",
    "\t\t\"\"\"Encode ``x`` up through every layer and decode it back down.\n",
    "\n",
    "\t\tBoth directions average ``self.k`` sampled passes. Returns\n",
    "\t\t``(reconstruction, top_level_code)``.\n",
    "\t\t\"\"\"\n",
    "\t\tx_gen = []\n",
    "\t\tfor _ in range(self.k):\n",
    "\t\t\tx_dash = x.clone()\n",
    "\t\t\tfor i in range(len(self.layer_parameters)):\n",
    "\t\t\t\t_, x_dash = self.sample_h(x_dash, self.layer_parameters[i]['W'], self.layer_parameters[i]['hb'])\n",
    "\t\t\tx_gen.append(x_dash)\n",
    "\t\tx_dash = torch.stack(x_gen)\n",
    "\t\tx_dash = torch.mean(x_dash, dim=0)\n",
    "\n",
    "\t\ty = x_dash\n",
    "\n",
    "\t\ty_gen = []\n",
    "\t\tfor _ in range(self.k):\n",
    "\t\t\ty_dash = y.clone()\n",
    "\t\t\t# Walk the layers from the top one back down to the visible layer\n",
    "\t\t\tfor i in reversed(range(len(self.layer_parameters))):\n",
    "\t\t\t\t_, y_dash = self.sample_v(y_dash, self.layer_parameters[i]['W'], self.layer_parameters[i]['vb'])\n",
    "\t\t\ty_gen.append(y_dash)\n",
    "\t\ty_dash = torch.stack(y_gen)\n",
    "\t\ty_dash = torch.mean(y_dash, dim=0)\n",
    "\n",
    "\t\treturn y_dash, x_dash\n",
    "\n",
    "\tdef initialize_model(self):\n",
    "\t\t\"\"\"Build a Linear/Sigmoid stack from the trained layers and replay their training history.\n",
    "\n",
    "\t\tFor every layer except the last, each recorded epoch's ``W`` / ``hb`` is loaded\n",
    "\t\tinto the matching Linear module and the model (plus a final Sigmoid) is evaluated.\n",
    "\t\tResults are written to column ``layer * 100 + epoch`` of the notebook-level arrays\n",
    "\t\t``LS_1_squence``, ``LS_2_squence``, ``J_w_squence``, ``LDA_squence`` and\n",
    "\t\t``train/test_loss/accuracy_squence``. The last Linear keeps its default initialisation.\n",
    "\n",
    "\t\tRelies on notebook globals: ``LayerOutputGetter``, ``Separability_images``,\n",
    "\t\t``Separability_labels``, ``W``, ``criterion``, ``binary_accuracy``, ``train_x``,\n",
    "\t\t``train_y``, ``test_x`` and ``test_y``.\n",
    "\n",
    "\t\tReturns the Sequential model without the final Sigmoid.\n",
    "\t\t\"\"\"\n",
    "\t\tprint(\"The Last layer will not be activated. The rest are activated using the Sigoid Function\")\n",
    "\t\tmodules = []\n",
    "\t\tfor index, layer in enumerate(self.layer_parameters):\n",
    "\t\t\tmodules.append(torch.nn.Linear(layer['W'].shape[1], layer['W'].shape[0]))\n",
    "\t\t\tif index < len(self.layer_parameters) - 1:\n",
    "\t\t\t\tmodules.append(torch.nn.Sigmoid())\n",
    "\t\tmodel = torch.nn.Sequential(*modules)\n",
    "\t\tprint(\t\t)\n",
    "\n",
    "\t\tepochs_per_layer = 100 # Column stride per layer; matches epochs=100 in train_DBN\n",
    "\t\tmodel_for_test = torch.nn.Sequential(model, torch.nn.Sigmoid())\n",
    "\n",
    "\t\tfor layer_no, layer in enumerate(model):\n",
    "\t\t\trbm_index = layer_no//2\n",
    "\t\t\t# Odd positions are Sigmoid modules; the last Linear is intentionally left untouched.\n",
    "\t\t\tif layer_no%2 != 0 or rbm_index == len(self.layer_parameters)-1:\n",
    "\t\t\t\tcontinue\n",
    "\n",
    "\t\t\thistory = self.layer_parameters_list[rbm_index]\n",
    "\t\t\t# RBM early stopping can record fewer than epochs_per_layer snapshots; replay only\n",
    "\t\t\t# those that exist (the remaining columns of the metric arrays stay at zero).\n",
    "\t\t\trecorded_epochs = min(epochs_per_layer, len(history['W_list']))\n",
    "\t\t\tfor epoch_i in range(recorded_epochs):\n",
    "\t\t\t\tself.layer_parameters[rbm_index]['W'] = history['W_list'][epoch_i].cpu()\n",
    "\t\t\t\tself.layer_parameters[rbm_index]['hb'] = history['hb_list'][epoch_i].cpu()\n",
    "\t\t\t\tmodel[layer_no].weight = torch.nn.Parameter(self.layer_parameters[rbm_index]['W'])\n",
    "\t\t\t\tmodel[layer_no].bias = torch.nn.Parameter(self.layer_parameters[rbm_index]['hb'])\n",
    "\n",
    "\t\t\t\t# Capture every module's output for one forward pass, then detach the hooks\n",
    "\t\t\t\t# so they do not pile up on the modules across iterations.\n",
    "\t\t\t\toutput_getter = LayerOutputGetter()\n",
    "\t\t\t\thook_handles = [module.register_forward_hook(output_getter) for module in model_for_test.modules()]\n",
    "\t\t\t\ttry:\n",
    "\t\t\t\t\twith torch.no_grad():\n",
    "\t\t\t\t\t\tmodel_for_test(Separability_images)\n",
    "\t\t\t\tfinally:\n",
    "\t\t\t\t\tfor handle in hook_handles:\n",
    "\t\t\t\t\t\thandle.remove()\n",
    "\n",
    "\t\t\t\tlayer_outputs = output_getter.outputs\n",
    "\t\t\t\t# Hooks fire child-first, so entry len(model) is the inner Sequential's own output,\n",
    "\t\t\t\t# a duplicate of its last child's output. For four layers this is index 7.\n",
    "\t\t\t\tdel layer_outputs[len(model)]\n",
    "\t\t\t\titer_l = rbm_index*epochs_per_layer+epoch_i\n",
    "\t\t\t\tprint(iter_l)\n",
    "\n",
    "\t\t\t\tfor i, output in enumerate(layer_outputs):\n",
    "\t\t\t\t\tout_put_numpy = output.detach().numpy()\n",
    "\t\t\t\t\t(LS_1_squence[i,iter_l], LS_2_squence[i,iter_l], J_w_squence[i,iter_l],\n",
    "\t\t\t\t\t\tLDA_squence[i,iter_l]) = W(out_put_numpy.reshape(out_put_numpy.shape[0],-1),Separability_labels)\n",
    "\n",
    "\t\t\t\twith torch.no_grad():\n",
    "\t\t\t\t\toutput_train = model_for_test(train_x)\n",
    "\t\t\t\t\tloss_train = criterion(output_train, train_y)\n",
    "\t\t\t\t\tacc_train = binary_accuracy(output_train, train_y)\n",
    "\t\t\t\t\toutput_test = model_for_test(test_x)\n",
    "\t\t\t\t\tloss_test = criterion(output_test, test_y)\n",
    "\t\t\t\t\tacc_test = binary_accuracy(output_test, test_y)\n",
    "\n",
    "\t\t\t\t# Store plain Python floats rather than tensors in the NumPy buffers.\n",
    "\t\t\t\ttrain_loss_squence[iter_l],train_accuracy_squence[iter_l]=loss_train.item(),float(acc_train)\n",
    "\t\t\t\ttest_loss_squence[iter_l],test_accuracy_squence[iter_l]=loss_test.item(),float(acc_test)\n",
    "\n",
    "\t\treturn model\n",
    "\n",
    "def trial_dataset():\n",
    "\t\"\"\"Return a shuffled (2000, 10) float32 tensor of toy binary rows.\n",
    "\n",
    "\tRows are generated in two groups of 1000: first with 0 as the rare bit\n",
    "\t(drawn when random.random() > 0.75), then with 1 as the rare bit.\n",
    "\n",
    "\tNOTE(review): an identical trial_dataset is defined in the RBM cell; this\n",
    "\tlater definition is the one in effect after both cells have run.\n",
    "\t\"\"\"\n",
    "\trows = []\n",
    "\tfor rare_bit in (0, 1):\n",
    "\t\tcommon_bit = 1 - rare_bit\n",
    "\t\tfor _ in range(1000):\n",
    "\t\t\trows.append([rare_bit if random.random()>0.75 else common_bit for _ in range(10)])\n",
    "\n",
    "\tdata = np.array(rows, dtype=np.float32)\n",
    "\tnp.random.shuffle(data)\n",
    "\treturn torch.from_numpy(data)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:22.020092Z",
     "start_time": "2023-05-16T06:48:22.007016Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "#from DBN import DBN\n",
    "#from load_dataset import MNIST\n",
    "from tqdm import trange\n",
    "import pandas as pd\n",
    "\n",
    "def initialize_model():\n",
    "\t\"\"\"Build the untrained baseline: 3072 -> 512 -> 128 -> 64 -> 1, Sigmoid after every Linear.\"\"\"\n",
    "\tlayer_sizes = [32*32*3, 512, 128, 64, 1]\n",
    "\tstack = []\n",
    "\tfor fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):\n",
    "\t\tstack.append(torch.nn.Linear(fan_in, fan_out))\n",
    "\t\tstack.append(torch.nn.Sigmoid())\n",
    "\treturn torch.nn.Sequential(*stack)\n",
    "\n",
    "def generate_batches(x, y, batch_size=64):\n",
    "\t\"\"\"Split 2-D ``x`` and its labels ``y`` into equally sized batches.\n",
    "\n",
    "\tTrailing rows that do not fill a whole batch are dropped. Returns a dict with\n",
    "\t'x' of shape (n_batches, batch_size, n_features) and 'y' of shape\n",
    "\t(n_batches, batch_size).\n",
    "\t\"\"\"\n",
    "\tdef whole_batches_only(tensor):\n",
    "\t\tkeep = int(tensor.shape[0] - tensor.shape[0]%batch_size)\n",
    "\t\treturn tensor[:keep]\n",
    "\n",
    "\tx_kept = whole_batches_only(x)\n",
    "\ty_kept = whole_batches_only(y)\n",
    "\tx_batched = torch.reshape(x_kept, (x_kept.shape[0]//batch_size, batch_size, x_kept.shape[1]))\n",
    "\ty_batched = torch.reshape(y_kept, (y_kept.shape[0]//batch_size, batch_size))\n",
    "\treturn {'x':x_batched, 'y':y_batched}\n",
    "\n",
    "\n",
    "        \n",
    "def test(model, train_x, train_y, test_x, test_y, epoch):\n",
    "\t\"\"\"Evaluate ``model`` on the test and train splits.\n",
    "\n",
    "\tReturns ``(epoch, loss_test, loss_train, acc_test, acc_train)``; the losses are\n",
    "\tBCELoss values as Python floats and the accuracies are whatever\n",
    "\tbinary_accuracy returns.\n",
    "\t\"\"\"\n",
    "\tbce = torch.nn.BCELoss()\n",
    "\n",
    "\tdef evaluate(inputs, labels):\n",
    "\t\tpredictions = model(inputs)\n",
    "\t\treturn bce(predictions, labels).item(), binary_accuracy(predictions, labels)\n",
    "\n",
    "\t# Test split first, then train split (same order as the model calls were made before)\n",
    "\tloss_test, acc_test = evaluate(test_x, test_y)\n",
    "\tloss_train, acc_train = evaluate(train_x, train_y)\n",
    "\n",
    "\treturn epoch, loss_test, loss_train, acc_test, acc_train\n",
    "\n",
    "\n",
    "def train(model, x, y, train_x, train_y, test_x, test_y, epochs=50):\n",
    "\t\"\"\"Fit ``model`` with Adam (lr=0.001) and BCELoss on mini-batches of ``x`` / ``y``.\n",
    "\n",
    "\tAssumes ``model`` emits one probability per sample (the output is flattened to\n",
    "\tmatch the per-batch label vector). After every epoch the result of\n",
    "\t``test(model, train_x, train_y, test_x, test_y, epoch+1)`` is appended to the\n",
    "\treturned ``progress`` list.\n",
    "\n",
    "\tReturns ``(model, progress)``.\n",
    "\t\"\"\"\n",
    "\tdataset = generate_batches(x, y)\n",
    "\n",
    "\tcriterion = torch.nn.BCELoss()\n",
    "\toptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\ttraining = trange(epochs)\n",
    "\tprogress = []\n",
    "\tfor epoch in training:\n",
    "\t\trunning_loss = 0\n",
    "\t\tacc = 0\n",
    "\t\tfor batch_x, target in zip(dataset['x'], dataset['y']):\n",
    "\t\t\t# generate_batches yields targets of shape (batch,), while the model returns\n",
    "\t\t\t# (batch, 1): flatten so BCELoss and the accuracy compare element-wise.\n",
    "\t\t\toutput = model(batch_x).reshape(-1,)\n",
    "\t\t\tloss = criterion(output, target)\n",
    "\t\t\t# .item() keeps acc a plain float so round() below works on it\n",
    "\t\t\tacc += binary_accuracy(output, target).item()\n",
    "\t\t\toptimizer.zero_grad()\n",
    "\t\t\tloss.backward()\n",
    "\t\t\toptimizer.step()\n",
    "\t\t\trunning_loss += loss.item()\n",
    "\t\trunning_loss /= len(dataset['y'])\n",
    "\t\tacc /= len(dataset['y'])\n",
    "\t\tprogress.append(test(model, train_x, train_y, test_x, test_y, epoch+1))\n",
    "\t\ttraining.set_description(str({'epoch': epoch+1, 'loss': round(running_loss, 4), 'acc': round(acc, 4)}))\n",
    "\n",
    "\treturn model, progress\n",
    "\n",
    "\n",
    "class LayerOutputGetter:\n",
    "    \"\"\"Forward-hook callable that records every output it is called with, in call order.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Grows by one entry per hook invocation\n",
    "        self.outputs = list()\n",
    "\n",
    "    def __call__(self, module, input, output):\n",
    "        # Signature matches what register_forward_hook passes; only the output is kept\n",
    "        self.outputs += [output]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:22.536968Z",
     "start_time": "2023-05-16T06:48:22.022904Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Star import: later cells call W(...), which is not defined in this notebook and\n",
    "# presumably comes from this module (TODO confirm and import it by name).\n",
    "from Linear_Separability_numpy import *\n",
    "\n",
    "# Column layout of the buffers below: 3 pre-trained layers x 100 epochs written by\n",
    "# DBN.initialize_model (columns 0-299), then 100 fine-tuning epochs written by\n",
    "# train_a / train_b (columns 300-399).\n",
    "num_epochs = 100*3+100 \n",
    "# Number of rows in the per-layer buffers; rows are indexed by position in the\n",
    "# list of captured layer outputs.\n",
    "len_x = 9\n",
    "x_plot = np.arange(num_epochs)*1\n",
    "\n",
    "# One value per (captured layer output, column), filled from the four values returned by W(...)\n",
    "LS_1_squence = np.zeros((len_x,num_epochs))\n",
    "LS_2_squence = np.zeros((len_x,num_epochs))\n",
    "J_w_squence = np.zeros((len_x,num_epochs))\n",
    "LDA_squence = np.zeros((len_x,num_epochs))\n",
    "\n",
    "# One loss / accuracy value per column\n",
    "train_loss_squence = np.zeros((num_epochs,))\n",
    "train_accuracy_squence = np.zeros((num_epochs,))\n",
    "test_loss_squence = np.zeros((num_epochs,))\n",
    "test_accuracy_squence = np.zeros((num_epochs,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:38.573419Z",
     "start_time": "2023-05-16T06:48:22.539792Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def load_dataset():\n",
    "    \"\"\"Download CIFAR-10 (into ./data) and return it as in-memory tensors.\n",
    "\n",
    "    Returns ``(train_x, train_y, test_x, test_y)``. Images are float tensors in\n",
    "    [0, 1] laid out as (num_samples, channels, height, width); labels are the\n",
    "    integer class ids produced by the dataset.\n",
    "    \"\"\"\n",
    "    # ToTensor already scales pixels to [0, 1] and returns channel-first images.\n",
    "    # The former Normalize(0.5, 0.5) followed by x * 0.5 + 0.5 cancelled each other\n",
    "    # out, so both steps are dropped.\n",
    "    transform = transforms.ToTensor()\n",
    "\n",
    "    # create train and test datasets\n",
    "    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "    # A single batch spanning the whole split materialises it as one tensor\n",
    "    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=False)\n",
    "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)\n",
    "\n",
    "    train_x, train_y = next(iter(train_loader))\n",
    "    test_x, test_y = next(iter(test_loader))\n",
    "\n",
    "    # No permute here: the batches are already (num_samples, channels, height, width).\n",
    "    # The previous permute(0, 3, 1, 2) reordered them to (N, W, C, H).\n",
    "    return train_x, train_y, test_x, test_y\n",
    "\n",
    "\n",
    "# Load both splits once; later cells reshape and subsample these same global names in place.\n",
    "train_x, train_y, test_x, test_y = load_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:38.579105Z",
     "start_time": "2023-05-16T06:48:38.575611Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "def binary_accuracy(preds, y):\n",
    "    \"\"\"Return the share of predictions that equal ``y`` after rounding, as a 0-dim tensor.\n",
    "\n",
    "    The count of matches is divided by the size of the first dimension, so\n",
    "    ``preds`` and ``y`` are expected to have matching shapes.\n",
    "    \"\"\"\n",
    "    hits = torch.eq(torch.round(preds), y).float()\n",
    "    return hits.sum() / len(hits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:38.600581Z",
     "start_time": "2023-05-16T06:48:38.582195Z"
    }
   },
   "outputs": [],
   "source": [
    "# Labels become float32 column vectors of shape (num_samples, 1)\n",
    "train_y = train_y.reshape(len(train_y), -1).to(torch.float32)\n",
    "test_y = test_y.reshape(len(test_y), -1).to(torch.float32)\n",
    "train_y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:38.741098Z",
     "start_time": "2023-05-16T06:48:38.602950Z"
    }
   },
   "outputs": [],
   "source": [
    "# Flatten every image into one float32 feature row\n",
    "train_x = train_x.reshape(len(train_x), -1).to(torch.float32)\n",
    "test_x = test_x.reshape(len(test_x), -1).to(torch.float32)\n",
    "train_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:38.752398Z",
     "start_time": "2023-05-16T06:48:38.745306Z"
    }
   },
   "outputs": [],
   "source": [
    "# Number of training rows currently held in train_x\n",
    "train_x.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:38.760133Z",
     "start_time": "2023-05-16T06:48:38.755753Z"
    }
   },
   "outputs": [],
   "source": [
    "# Notebook-level loss; DBN.initialize_model reads this global when it evaluates the model\n",
    "criterion = torch.nn.BCELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:48:54.302324Z",
     "start_time": "2023-05-16T06:48:38.763333Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "if __name__ == '__main__':\n",
    "# \tmnist = MNIST()\n",
    "# \ttrain_x, train_y, test_x, test_y = mnist.load_dataset()\n",
    "\n",
    "\t# Row indices of the samples labelled 0 or 1 (the two masks are combined with +)\n",
    "\tall_index_train = np.where((train_y==0) + (train_y==1))[0]\n",
    "\tall_index_test = np.where((test_y==0) + (test_y==1))[0]\n",
    "\t# Fixed seed so the three draws below are reproducible; their order matters\n",
    "\tnp.random.seed(1234)\n",
    "\t# randint samples with replacement, so an index can be picked more than once\n",
    "\ttrain_index = all_index_train[np.random.randint(0,len(all_index_train),2000)]\n",
    "\ttest_index = all_index_test[np.random.randint(0,len(all_index_test),1000)]\n",
    "\t# 500 of the training picks, used for the layer-wise separability measurements\n",
    "\tLS_index = train_index[np.random.randint(0,len(train_index),500)]\n",
    "\n",
    "\tSeparability_images,Separability_labels = train_x[LS_index,:], train_y[LS_index,]\n",
    "\n",
    "\t# Hidden sizes of the stacked RBMs\n",
    "\tlayers = [512, 128, 64, 1]\n",
    "\n",
    "\t# NOTE: overwrites the global splits with the subsample, so re-running this cell\n",
    "\t# without reloading the data samples from the already reduced tensors.\n",
    "\ttrain_x,train_y = train_x[train_index,:],train_y[train_index]\n",
    "\ttest_x,test_y = test_x[test_index,:],test_y[test_index]    \n",
    "    \n",
    "\t# The save file is named mnist although load_dataset() reads CIFAR-10\n",
    "\tdbn = DBN(train_x.shape[1], layers, savefile='mnist_trained_dbn.pt')\n",
    "\tdbn.train_DBN(train_x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:51:10.940886Z",
     "start_time": "2023-05-16T06:48:54.305293Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the classifier from the trained DBN. initialize_model() also fills the\n",
    "# *_squence buffers as a side effect while it replays the per-epoch RBM weights.\n",
    "model = dbn.initialize_model()\n",
    "\n",
    "# Append a final Sigmoid after the last Linear layer\n",
    "completed_model = torch.nn.Sequential(model, torch.nn.Sigmoid())\n",
    "# Saves the whole module object rather than a state_dict; the file name says mnist\n",
    "# although load_dataset() reads CIFAR-10.\n",
    "torch.save(completed_model, 'mnist_trained_dbn_classifier.pt')\n",
    "print(completed_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:51:11.003200Z",
     "start_time": "2023-05-16T06:51:10.979845Z"
    }
   },
   "outputs": [],
   "source": [
    "# Independent copies of the shared metric buffers; train_a writes into these\n",
    "(LS_1_squence_1, LS_2_squence_1, J_w_squence_1, LDA_squence_1,\n",
    " train_loss_squence_1, train_accuracy_squence_1,\n",
    " test_loss_squence_1, test_accuracy_squence_1) = (\n",
    "    buffer.copy() for buffer in (\n",
    "        LS_1_squence, LS_2_squence, J_w_squence, LDA_squence,\n",
    "        train_loss_squence, train_accuracy_squence,\n",
    "        test_loss_squence, test_accuracy_squence,\n",
    "    )\n",
    ")\n",
    "\n",
    "def train_a(name,model, x, y, train_x, train_y, test_x, test_y, epochs=100):\n",
    "\t\"\"\"Fine-tune ``model`` with Adam/BCELoss and log per-epoch metrics into the ``*_squence_1`` buffers.\n",
    "\n",
    "\tAfter each epoch the outputs of every module for ``Separability_images`` are\n",
    "\tcaptured with forward hooks and passed to ``W(...)``; the four returned values\n",
    "\tand the train/test loss and accuracy are written to column ``300 + epoch``.\n",
    "\tWhen ``name == 'Pre-Training'`` captured output 7 is dropped first (the\n",
    "\tduplicate produced by the inner Sequential of the wrapped DBN model).\n",
    "\n",
    "\tRelies on notebook globals: ``generate_batches``, ``binary_accuracy``, ``test``,\n",
    "\t``LayerOutputGetter``, ``Separability_images``, ``Separability_labels`` and ``W``.\n",
    "\tThe buffers have 400 columns, so ``epochs`` must not exceed 100.\n",
    "\n",
    "\tReturns ``(model, progress)`` where ``progress`` holds one ``test(...)`` tuple per epoch.\n",
    "\t\"\"\"\n",
    "\tdataset = generate_batches(x, y)\n",
    "\n",
    "\tcriterion = torch.nn.BCELoss()\n",
    "\toptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\ttraining = trange(epochs)\n",
    "\tprogress = []\n",
    "\tfor epoch in training:\n",
    "\t\trunning_loss = 0\n",
    "\t\tacc = 0\n",
    "\t\tfor batch_x, target in zip(dataset['x'], dataset['y']):\n",
    "\t\t\t# Flatten (batch, 1) predictions to match the (batch,) targets\n",
    "\t\t\toutput = model(batch_x).reshape(-1,)\n",
    "\t\t\tloss = criterion(output, target)\n",
    "\t\t\tacc += binary_accuracy(output, target)\n",
    "\t\t\toptimizer.zero_grad()\n",
    "\t\t\tloss.backward()\n",
    "\t\t\toptimizer.step()\n",
    "\t\t\trunning_loss += loss.item()\n",
    "\t\trunning_loss /= len(dataset['y'])\n",
    "\t\tacc /= len(dataset['y'])\n",
    "\n",
    "\t\tprogress.append(test(model, train_x, train_y, test_x, test_y, epoch+1))\n",
    "\t\ttraining.set_description(str({'epoch': epoch+1, 'loss': running_loss, 'acc': acc}))\n",
    "\n",
    "\t\t# Capture every module's output for one forward pass, then remove the hooks so\n",
    "\t\t# they do not accumulate on the model (and keep collecting tensors) across epochs.\n",
    "\t\toutput_getter = LayerOutputGetter()\n",
    "\t\thook_handles = [module.register_forward_hook(output_getter) for module in model.modules()]\n",
    "\t\ttry:\n",
    "\t\t\twith torch.no_grad():\n",
    "\t\t\t\tmodel(Separability_images)\n",
    "\t\tfinally:\n",
    "\t\t\tfor handle in hook_handles:\n",
    "\t\t\t\thandle.remove()\n",
    "\n",
    "\t\tlayer_outputs = output_getter.outputs\n",
    "\n",
    "\t\t# Fine-tuning columns start after the 300 pre-training columns\n",
    "\t\titer_l = 300+epoch\n",
    "\n",
    "\t\tif name=='Pre-Training':\n",
    "\t\t\tdel layer_outputs[7]\n",
    "\n",
    "\t\tfor i, output in enumerate(layer_outputs):\n",
    "\t\t\tout_put_numpy = output.detach().numpy()\n",
    "\t\t\t(LS_1_squence_1[i,iter_l], LS_2_squence_1[i,iter_l], J_w_squence_1[i,iter_l],\n",
    "\t\t\t\tLDA_squence_1[i,iter_l]) = W(out_put_numpy.reshape(out_put_numpy.shape[0],-1),Separability_labels)\n",
    "\n",
    "\t\twith torch.no_grad():\n",
    "\t\t\toutput_train = model(train_x)\n",
    "\t\t\tloss_train = criterion(output_train, train_y)\n",
    "\t\t\tacc_train = binary_accuracy(output_train, train_y)\n",
    "\n",
    "\t\t\toutput_test = model(test_x)\n",
    "\t\t\tloss_test = criterion(output_test, test_y)\n",
    "\t\t\tacc_test = binary_accuracy(output_test, test_y)\n",
    "\n",
    "\t\t# Store plain Python floats rather than tensors in the NumPy buffers\n",
    "\t\ttrain_loss_squence_1[iter_l],train_accuracy_squence_1[iter_l]=loss_train.item(),float(acc_train)\n",
    "\t\ttest_loss_squence_1[iter_l],test_accuracy_squence_1[iter_l]=loss_test.item(),float(acc_test)\n",
    "\n",
    "\treturn model, progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:51:11.019803Z",
     "start_time": "2023-05-16T06:51:11.005698Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Independent copies of the shared metric buffers; train_b writes into these\n",
    "(LS_1_squence_2, LS_2_squence_2, J_w_squence_2, LDA_squence_2,\n",
    " train_loss_squence_2, train_accuracy_squence_2,\n",
    " test_loss_squence_2, test_accuracy_squence_2) = (\n",
    "    buffer.copy() for buffer in (\n",
    "        LS_1_squence, LS_2_squence, J_w_squence, LDA_squence,\n",
    "        train_loss_squence, train_accuracy_squence,\n",
    "        test_loss_squence, test_accuracy_squence,\n",
    "    )\n",
    ")\n",
    "\n",
    "def train_b(name,model, x, y, train_x, train_y, test_x, test_y, epochs=100, iter_offset=300):\n",
    "\t\"\"\"Train `model` with Adam + BCE and log per-epoch separability and loss metrics.\n",
    "\n",
    "\tAfter every epoch the function (1) appends `test(...)` to the returned progress\n",
    "\tlist, (2) captures each module's output for `Separability_images` through\n",
    "\ttemporary forward hooks and stores the four values returned by `W` in the\n",
    "\tmodule-level `*_squence_2` arrays, and (3) stores full train/test loss and\n",
    "\taccuracy in the matching `*_squence_2` arrays.\n",
    "\n",
    "\tArgs:\n",
    "\t\tname: run label; when it equals 'Pre-Training' the captured output at index 7 is dropped.\n",
    "\t\tmodel: torch module that is optimised in place.\n",
    "\t\tx, y: data handed to `generate_batches` for the optimisation loop.\n",
    "\t\ttrain_x, train_y, test_x, test_y: data used for the per-epoch evaluation.\n",
    "\t\tepochs: number of passes over the batches.\n",
    "\t\titer_offset: column offset added to the epoch index when writing into the\n",
    "\t\t\tmetric arrays. The default 300 keeps the original hard-coded behaviour\n",
    "\t\t\t(presumably the number of columns already filled by an earlier run - confirm).\n",
    "\n",
    "\tReturns:\n",
    "\t\tTuple `(model, progress)` where `progress` holds one `test(...)` result per epoch.\n",
    "\t\"\"\"\n",
    "\tdataset = generate_batches(x, y)\n",
    "\n",
    "\tcriterion = torch.nn.BCELoss()\n",
    "\toptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\ttraining = trange(epochs)\n",
    "\tprogress = []\n",
    "\tfor epoch in training:\n",
    "\t\trunning_loss = 0\n",
    "\t\tacc = 0\n",
    "\t\tfor batch_x, target in zip(dataset['x'], dataset['y']):\n",
    "\t\t\t# Flatten the predictions once before comparing them with the batch targets.\n",
    "\t\t\toutput = model(batch_x).reshape(-1,)\n",
    "\t\t\tloss = criterion(output, target)\n",
    "\t\t\tacc += binary_accuracy(output, target)\n",
    "\t\t\toptimizer.zero_grad()\n",
    "\t\t\tloss.backward()\n",
    "\t\t\toptimizer.step()\n",
    "\t\t\trunning_loss += loss.item()\n",
    "\t\trunning_loss /= len(dataset['y'])\n",
    "\t\tacc /= len(dataset['y'])\n",
    "\n",
    "\t\tprogress.append(test(model, train_x, train_y, test_x, test_y, epoch+1))\n",
    "\t\ttraining.set_description(str({'epoch': epoch+1, 'loss': running_loss, 'acc':acc}))\n",
    "\n",
    "\t\t# Capture every module's output for the probe batch. The hooks are removed\n",
    "\t\t# right after the probe pass; otherwise a new set would pile up on the model\n",
    "\t\t# each epoch and keep firing (and retaining tensors) on every later forward.\n",
    "\t\toutput_getter = LayerOutputGetter()\n",
    "\t\thook_handles = [module.register_forward_hook(output_getter) for module in model.modules()]\n",
    "\t\ttry:\n",
    "\t\t\twith torch.no_grad():\n",
    "\t\t\t\tmodel(Separability_images)\n",
    "\t\tfinally:\n",
    "\t\t\tfor handle in hook_handles:\n",
    "\t\t\t\thandle.remove()\n",
    "\n",
    "\t\t# get the outputs for each layer\n",
    "\t\tlayer_outputs = output_getter.outputs\n",
    "\n",
    "\t\titer_l = iter_offset+epoch\n",
    "\n",
    "\t\tif name=='Pre-Training':\n",
    "\t\t\t# NOTE(review): drops the captured output at index 7; presumably an extra\n",
    "\t\t\t# module of the pre-trained model that has no row in the metric arrays - confirm.\n",
    "\t\t\tdel layer_outputs[7]\n",
    "\n",
    "\t\tfor i, layer_output in enumerate(layer_outputs):\n",
    "\t\t\tout_put_numpy = layer_output.detach().cpu().numpy()\n",
    "\t\t\t# One row per captured output, flattened to (n_samples, -1) before scoring.\n",
    "\t\t\tLS_1_squence_2[i,iter_l],LS_2_squence_2[i,iter_l],J_w_squence_2[i,iter_l],\\\n",
    "\t\t\tLDA_squence_2[i,iter_l]=W(out_put_numpy.reshape(out_put_numpy.shape[0],-1),Separability_labels)\n",
    "\n",
    "\t\t# Evaluation only: no autograd graph is needed for these full-set passes.\n",
    "\t\t# NOTE(review): unlike the batch loop, the outputs are not flattened here;\n",
    "\t\t# confirm their shape matches train_y / test_y.\n",
    "\t\twith torch.no_grad():\n",
    "\t\t\toutput_train = model(train_x)\n",
    "\t\t\tloss_train = criterion(output_train, train_y).item()\n",
    "\t\t\tacc_train = binary_accuracy(output_train, train_y)\n",
    "\n",
    "\t\t\toutput_test = model(test_x)\n",
    "\t\t\tloss_test = criterion(output_test, test_y).item()\n",
    "\t\t\tacc_test = binary_accuracy(output_test, test_y)\n",
    "\n",
    "\t\ttrain_loss_squence_2[iter_l],train_accuracy_squence_2[iter_l]=loss_train,acc_train\n",
    "\t\ttest_loss_squence_2[iter_l],test_accuracy_squence_2[iter_l]=loss_test,acc_test\n",
    "\n",
    "\treturn model, progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T06:52:22.261212Z",
     "start_time": "2023-05-16T06:51:11.022307Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print('\\n'*3)\n",
    "print(\"Without Pre-Training\")\n",
    "\n",
    "# Baseline run: freshly initialised network trained by train_a.\n",
    "# train_x / train_y are passed twice (first pair and second pair of data arguments).\n",
    "model = initialize_model()\n",
    "name = 'initialize_model'\n",
    "model, progress = train_a(name,model, train_x, train_y, train_x, train_y, test_x, test_y)\n",
    "\n",
    "# Turn the per-epoch records into a labelled table and persist it as CSV.\n",
    "progress = pd.DataFrame(\n",
    "    np.array(progress),\n",
    "    columns=['epochs', 'test loss', 'train loss', 'test acc', 'train acc'],\n",
    ")\n",
    "progress.to_csv('DBN_without_pretraining_classifier.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-16T06:48:22.323Z"
    }
   },
   "outputs": [],
   "source": [
    "print(\"With Pre-Training\")\n",
    "\n",
    "# Pre-trained run: fine-tune completed_model with train_b.\n",
    "# train_x / train_y serve both as the optimisation data and as the train-evaluation set.\n",
    "name = 'Pre-Training'\n",
    "model, progress = train_b(name,completed_model, train_x, train_y, train_x, train_y, test_x, test_y)\n",
    "\n",
    "# Turn the per-epoch records into a labelled table and persist it as CSV.\n",
    "progress = pd.DataFrame(\n",
    "    np.array(progress),\n",
    "    columns=['epochs', 'test loss', 'train loss', 'test acc', 'train acc'],\n",
    ")\n",
    "progress.to_csv('DBN_with_pretraining_classifier.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-16T06:48:22.325Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Baseline separability of the raw probe images: W is evaluated once on the\n",
    "# flattened inputs and each returned value is broadcast across all num_epochs\n",
    "# columns of its own (1, num_epochs) buffer.\n",
    "LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base = (\n",
    "    np.zeros((1, num_epochs)) for _ in range(4)\n",
    ")\n",
    "LS_1_squence_base[0], LS_2_squence_base[0], J_w_squence_base[0], LDA_squence_base[0] = W(\n",
    "    Separability_images.reshape(len(Separability_images), -1).numpy(),\n",
    "    Separability_labels,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-16T06:48:22.327Z"
    }
   },
   "outputs": [],
   "source": [
    "# Display labels handed to the plotting helpers.\n",
    "# NOTE(review): presumably one label per captured layer output (row index of the\n",
    "# *_squence arrays); confirm the trailing duplicate 'Softmax' is intended.\n",
    "layer_name_list = [\n",
    "    '1-dense', 'Sigmoid',\n",
    "    '2-dense', 'Sigmoid',\n",
    "    '3-dense', 'Sigmoid',\n",
    "    '4-dense', 'Softmax',\n",
    "    'Softmax',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-16T06:48:22.329Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import time\n",
    "\n",
    "# Silence every warning from here on and record the local time of this cell's execution.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "time_tuple = time.localtime(time.time())\n",
    "\n",
    "# Bundle of everything recorded for the run stored in the `_1` buffers, plus the raw-input baselines.\n",
    "info = {\n",
    "    'layer_name_list': layer_name_list,\n",
    "    'x_plot': x_plot,\n",
    "    'LS_1_squence': LS_1_squence_1,\n",
    "    'LS_2_squence': LS_2_squence_1,\n",
    "    'J_w_squence': J_w_squence_1,\n",
    "    'LDA_squence': LDA_squence_1,\n",
    "    'train_loss_squence': train_loss_squence_1,\n",
    "    'train_accuracy_squence': train_accuracy_squence_1,\n",
    "    'test_loss_squence': test_loss_squence_1,\n",
    "    'test_accuracy_squence': test_accuracy_squence_1,\n",
    "    'LS_1_squence_base': LS_1_squence_base,\n",
    "    'LS_2_squence_base': LS_2_squence_base,\n",
    "    'J_w_squence_base': J_w_squence_base,\n",
    "    'LDA_squence_base': LDA_squence_base,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-16T06:48:22.330Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import time\n",
    "\n",
    "# Silence every warning from here on and record the local time of this cell's execution.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "time_tuple = time.localtime(time.time())\n",
    "\n",
    "# Bundle of everything recorded for the run stored in the `_2` buffers, plus the raw-input baselines.\n",
    "# NOTE(review): this rebinds `info`, so any dict previously held under that name is replaced.\n",
    "info = {\n",
    "    'layer_name_list': layer_name_list,\n",
    "    'x_plot': x_plot,\n",
    "    'LS_1_squence': LS_1_squence_2,\n",
    "    'LS_2_squence': LS_2_squence_2,\n",
    "    'J_w_squence': J_w_squence_2,\n",
    "    'LDA_squence': LDA_squence_2,\n",
    "    'train_loss_squence': train_loss_squence_2,\n",
    "    'train_accuracy_squence': train_accuracy_squence_2,\n",
    "    'test_loss_squence': test_loss_squence_2,\n",
    "    'test_accuracy_squence': test_accuracy_squence_2,\n",
    "    'LS_1_squence_base': LS_1_squence_base,\n",
    "    'LS_2_squence_base': LS_2_squence_base,\n",
    "    'J_w_squence_base': J_w_squence_base,\n",
    "    'LDA_squence_base': LDA_squence_base,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-16T06:48:22.331Z"
    }
   },
   "outputs": [],
   "source": [
    "# Figures for the run stored in the `_1` buffers: per-layer separability curves\n",
    "# (with the raw-input baselines) and the loss/accuracy curves.\n",
    "Separability_figure = plot_Separability_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    LS_1_squence_1, LS_2_squence_1, J_w_squence_1, LDA_squence_1,\n",
    "    LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base,\n",
    ")\n",
    "net_figure = plot_net_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    train_loss_squence_1, train_accuracy_squence_1,\n",
    "    test_loss_squence_1, test_accuracy_squence_1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2023-05-16T06:48:22.334Z"
    }
   },
   "outputs": [],
   "source": [
    "# Figures for the run stored in the `_2` buffers: per-layer separability curves\n",
    "# (with the raw-input baselines) and the loss/accuracy curves.\n",
    "Separability_figure = plot_Separability_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    LS_1_squence_2, LS_2_squence_2, J_w_squence_2, LDA_squence_2,\n",
    "    LS_1_squence_base, LS_2_squence_base, J_w_squence_base, LDA_squence_base,\n",
    ")\n",
    "net_figure = plot_net_figure(\n",
    "    layer_name_list, x_plot,\n",
    "    train_loss_squence_2, train_accuracy_squence_2,\n",
    "    test_loss_squence_2, test_accuracy_squence_2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Pytorch",
   "language": "python",
   "name": "pytorch10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
