{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import SGD, RMSprop\n",
    "from torchvision import datasets, transforms\n",
    "dtype = torch.cuda.FloatTensor\n",
    "\n",
    "from src.components import DeterministicTensor, StochasticTensor, StochasticNetwork, FF_BNN, SVHN_BCNN, BayesianResNet20\n",
    "from src.optimizers import SGLD, pSGLD\n",
    "from sklearn import preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fa574508e30>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Seed every RNG the notebook touches (Python, NumPy, PyTorch) to make runs as repeatable as possible.\n",
    "SEED = 2\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Adapted from https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627 (an MNIST tutorial);\n",
    "# here the same pipeline is applied to CIFAR-10.\n",
    "# Fall back to the CPU so the notebook can still be executed on a machine without a GPU.\n",
    "dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Map each of the 3 RGB channels from [0, 1] to [-1, 1].\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "# The CIFAR-10 archive contains both the train and the test batches, so a single root directory\n",
    "# serves both splits instead of downloading and extracting the same archive twice.\n",
    "DATA_ROOT = 'PATH_TO_STORE_TRAINSET'\n",
    "trainset = datasets.CIFAR10(DATA_ROOT, download=True, train=True, transform=transform)\n",
    "valset = datasets.CIFAR10(DATA_ROOT, download=True, train=False, transform=transform)\n",
    "\n",
    "# Apply the transform once and keep every (image, label) pair on the device,\n",
    "# so batching needs no further host-to-device copies.\n",
    "trainset = [(x.to(dev), torch.tensor(y, device=dev)) for x, y in trainset]\n",
    "valset = [(x.to(dev), torch.tensor(y, device=dev)) for x, y in valset]\n",
    "\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)\n",
    "testloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True)\n",
    "N = len(trainset)  # dataset size used for training; switch to len(valset) when running on the test split\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "# Peek at one batch to confirm tensor shapes.\n",
    "# DataLoader iterators no longer expose a .next() method (removed in PyTorch 1.13); use the builtin next().\n",
    "dataiter = iter(testloader)\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "print(images.shape)\n",
    "print(labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-stage procedure: first fit the model with a deterministic optimizer to find the MAP solution\n",
    "# (RMSprop in the next cell, despite the `sgd_` prefix), then rerun it with SGLD to sample the posterior.\n",
    "\n",
    "# Extra architecture arguments; currently none are passed (presumably BayesianResNet20 falls back to its\n",
    "# own defaults -- confirm in src.components). An earlier fully-connected configuration, kept for reference:\n",
    "#   num_inputs=C*H*W of one image, num_outputs=10, num_layers=2, hidden_sizes=[50, 50],\n",
    "#   activation_func=nn.ReLU, chain_length=4000, stochastic_biases=False, prior_std=0.3,\n",
    "#   output_distribution=\"categorical\", output_dist_const_params=dict()\n",
    "model_arch_args = {}\n",
    "\n",
    "# Constructor arguments for the MAP run: no layer/random/permuted grouping; max_groups and dropout_prob unset.\n",
    "sgd_model_args = {\n",
    "    'group_by_layers': False,\n",
    "    'use_random_groups': False,\n",
    "    'use_permuted_groups': False,\n",
    "    'max_groups': None,\n",
    "    'dropout_prob': None,\n",
    "    **model_arch_args,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer settings for the MAP run: update both tensor groups (update_determ / update_stoch) with RMSprop;\n",
    "# plain SGD and the Langevin samplers (SGLD / pSGLD) are switched off.\n",
    "map_optimizer_kwargs = dict(\n",
    "    update_determ=True,\n",
    "    update_stoch=True,\n",
    "    lr=1e-3,  # commented-out alternatives from earlier runs: 1e-8, 1e-5\n",
    "    rmsprop=True,\n",
    "    sgd=False,\n",
    "    sgld=False,\n",
    "    psgld=False,\n",
    ")\n",
    "\n",
    "sgd_model = BayesianResNet20(**sgd_model_args)\n",
    "sgd_model.initialize_optimizer(**map_optimizer_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Persist the constructor arguments (presumably so the same model can be rebuilt later, e.g. for the SGLD run).\n",
    "# The `with` block guarantees the file is flushed and closed, even if pickling raises.\n",
    "with open(\"./resnet20_sgd_model_params.pickle\", \"wb\") as params_file:\n",
    "    pickle.dump(sgd_model_args, params_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sgd_model = sgd_model.to(dev)\n",
    "\n",
    "# Each StochasticTensor's prior loc/scale are moved explicitly -- presumably they are not registered\n",
    "# as module buffers, so the .to(dev) above does not reach them (confirm in src.components).\n",
    "stochastic_tensors = [tensor for tensor in sgd_model.tensor_dict.values() if isinstance(tensor, StochasticTensor)]\n",
    "for tensor in stochastic_tensors:\n",
    "    prior = tensor.prior_dist\n",
    "    prior.loc = prior.loc.to(dev)\n",
    "    prior.scale = prior.scale.to(dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 1 / 2000, Loss: 4421362534.5, CrossEntropy: 1.7264758348464966, Accuracy: 0.350943094629156\n",
      "Iter 2 / 2000, Loss: 3291276845.5, CrossEntropy: 1.2746440172195435, Accuracy: 0.5322290601023018\n",
      "Iter 3 / 2000, Loss: 2680235659.75, CrossEntropy: 1.030050277709961, Accuracy: 0.6300431585677749\n",
      "Iter 4 / 2000, Loss: 2285055144.0, CrossEntropy: 0.8718842267990112, Accuracy: 0.6863890664961637\n",
      "Iter 5 / 2000, Loss: 1982837788.25, CrossEntropy: 0.7512186765670776, Accuracy: 0.733611732736573\n",
      "Iter 6 / 2000, Loss: 1760809752.5, CrossEntropy: 0.6623088121414185, Accuracy: 0.768454283887468\n",
      "Iter 7 / 2000, Loss: 1591924834.25, CrossEntropy: 0.5949235558509827, Accuracy: 0.7922634271099743\n",
      "Iter 8 / 2000, Loss: 1450012311.5, CrossEntropy: 0.5379517674446106, Accuracy: 0.8116008631713555\n",
      "Iter 9 / 2000, Loss: 1335786847.75, CrossEntropy: 0.49239733815193176, Accuracy: 0.8274976023017903\n",
      "Iter 10 / 2000, Loss: 1232274439.5, CrossEntropy: 0.4509517252445221, Accuracy: 0.8416520140664961\n",
      "Iter 11 / 2000, Loss: 1133404593.625, CrossEntropy: 0.4113626182079315, Accuracy: 0.8573329603580563\n",
      "Iter 12 / 2000, Loss: 1047738266.5, CrossEntropy: 0.3769586980342865, Accuracy: 0.8680226982097188\n",
      "Iter 13 / 2000, Loss: 964345700.375, CrossEntropy: 0.3438212275505066, Accuracy: 0.8799552429667519\n",
      "Iter 14 / 2000, Loss: 896155176.25, CrossEntropy: 0.31654056906700134, Accuracy: 0.8888147378516624\n",
      "Iter 15 / 2000, Loss: 812552368.5, CrossEntropy: 0.28320184350013733, Accuracy: 0.9012907608695653\n",
      "Iter 16 / 2000, Loss: 752092529.75, CrossEntropy: 0.2588290572166443, Accuracy: 0.9079763427109975\n",
      "Iter 17 / 2000, Loss: 697492774.8125, CrossEntropy: 0.23702989518642426, Accuracy: 0.9169876918158568\n",
      "Iter 18 / 2000, Loss: 637185400.3125, CrossEntropy: 0.21294960379600525, Accuracy: 0.9240489130434784\n",
      "Iter 19 / 2000, Loss: 595524482.375, CrossEntropy: 0.19628016650676727, Accuracy: 0.9307424872122763\n",
      "Iter 20 / 2000, Loss: 548104821.0625, CrossEntropy: 0.17719672620296478, Accuracy: 0.9372642263427109\n",
      "Iter 21 / 2000, Loss: 504946801.5, CrossEntropy: 0.1600395292043686, Accuracy: 0.943869884910486\n",
      "Iter 22 / 2000, Loss: 468633526.25, CrossEntropy: 0.1453671157360077, Accuracy: 0.9482336956521739\n",
      "Iter 23 / 2000, Loss: 447089944.3125, CrossEntropy: 0.1367846429347992, Accuracy: 0.9514585997442455\n",
      "Iter 24 / 2000, Loss: 415171909.3125, CrossEntropy: 0.12402287125587463, Accuracy: 0.9552429667519181\n",
      "Iter 25 / 2000, Loss: 385903064.40625, CrossEntropy: 0.11222892254590988, Accuracy: 0.959494884910486\n",
      "Iter 26 / 2000, Loss: 377662948.71875, CrossEntropy: 0.10895416140556335, Accuracy: 0.9604699488491049\n",
      "Iter 27 / 2000, Loss: 360847250.46875, CrossEntropy: 0.1022074967622757, Accuracy: 0.9634271099744245\n",
      "Iter 28 / 2000, Loss: 346743179.625, CrossEntropy: 0.09662183374166489, Accuracy: 0.9648377557544756\n",
      "Iter 29 / 2000, Loss: 330009482.75, CrossEntropy: 0.08984000980854034, Accuracy: 0.9685461956521739\n",
      "Iter 30 / 2000, Loss: 316161574.5625, CrossEntropy: 0.08436132967472076, Accuracy: 0.9703204923273657\n",
      "Iter 31 / 2000, Loss: 306310897.78125, CrossEntropy: 0.0803227350115776, Accuracy: 0.9717071611253197\n",
      "Iter 32 / 2000, Loss: 292295061.40625, CrossEntropy: 0.07473962008953094, Accuracy: 0.9732416879795397\n",
      "Iter 33 / 2000, Loss: 291603256.15625, CrossEntropy: 0.07444943487644196, Accuracy: 0.9732656649616368\n",
      "Iter 34 / 2000, Loss: 275917955.8125, CrossEntropy: 0.06822313368320465, Accuracy: 0.9755035166240409\n",
      "Iter 35 / 2000, Loss: 276053091.3125, CrossEntropy: 0.0683625191450119, Accuracy: 0.9770540281329922\n",
      "Iter 36 / 2000, Loss: 266327425.28125, CrossEntropy: 0.06431802362203598, Accuracy: 0.9769501278772379\n",
      "Iter 37 / 2000, Loss: 260121667.875, CrossEntropy: 0.06184309720993042, Accuracy: 0.9782688618925831\n",
      "Iter 38 / 2000, Loss: 253315805.3125, CrossEntropy: 0.05912578105926514, Accuracy: 0.9785366048593351\n",
      "Iter 39 / 2000, Loss: 249484735.90625, CrossEntropy: 0.05761607363820076, Accuracy: 0.9791120524296675\n",
      "Iter 40 / 2000, Loss: 246039297.3125, CrossEntropy: 0.05622851848602295, Accuracy: 0.9800791240409207\n",
      "Iter 41 / 2000, Loss: 244317410.3125, CrossEntropy: 0.05546925216913223, Accuracy: 0.9808543797953965\n",
      "Iter 42 / 2000, Loss: 245710550.875, CrossEntropy: 0.055982306599617004, Accuracy: 0.97965952685422\n",
      "Iter 43 / 2000, Loss: 231421565.875, CrossEntropy: 0.05026863515377045, Accuracy: 0.9825967071611253\n",
      "Iter 44 / 2000, Loss: 229206156.15625, CrossEntropy: 0.04955284297466278, Accuracy: 0.9820692135549872\n",
      "Iter 45 / 2000, Loss: 227273144.78125, CrossEntropy: 0.04865493252873421, Accuracy: 0.9829124040920717\n",
      "Iter 46 / 2000, Loss: 224256751.1875, CrossEntropy: 0.04742730036377907, Accuracy: 0.9830043158567775\n",
      "Iter 47 / 2000, Loss: 222954135.9375, CrossEntropy: 0.04688316211104393, Accuracy: 0.9837236253196932\n",
      "Iter 48 / 2000, Loss: 222560381.0, CrossEntropy: 0.046739473938941956, Accuracy: 0.9838315217391305\n",
      "Iter 49 / 2000, Loss: 220294686.46875, CrossEntropy: 0.04578021168708801, Accuracy: 0.9840632992327366\n",
      "Iter 50 / 2000, Loss: 212894531.71875, CrossEntropy: 0.04296747222542763, Accuracy: 0.9849184782608695\n",
      "Iter 51 / 2000, Loss: 212656436.0625, CrossEntropy: 0.04275834187865257, Accuracy: 0.9853820332480818\n",
      "Iter 52 / 2000, Loss: 198255830.625, CrossEntropy: 0.03697028383612633, Accuracy: 0.9865808823529412\n",
      "Iter 53 / 2000, Loss: 208691586.65625, CrossEntropy: 0.04115442559123039, Accuracy: 0.9856098145780052\n",
      "Iter 54 / 2000, Loss: 211260857.0, CrossEntropy: 0.04216769337654114, Accuracy: 0.9853101023017904\n",
      "Iter 55 / 2000, Loss: 197804472.96875, CrossEntropy: 0.036761440336704254, Accuracy: 0.9867806905370844\n",
      "Iter 56 / 2000, Loss: 197729380.78125, CrossEntropy: 0.03681652992963791, Accuracy: 0.9870644181585677\n",
      "Iter 57 / 2000, Loss: 197924414.0625, CrossEntropy: 0.03675810620188713, Accuracy: 0.9871922953964194\n",
      "Iter 58 / 2000, Loss: 198833153.46875, CrossEntropy: 0.03715282306075096, Accuracy: 0.9869685102301791\n",
      "Iter 59 / 2000, Loss: 192255137.65625, CrossEntropy: 0.03454923629760742, Accuracy: 0.9879275895140666\n",
      "Iter 60 / 2000, Loss: 189484099.8125, CrossEntropy: 0.03336172550916672, Accuracy: 0.9883511828644501\n",
      "Iter 61 / 2000, Loss: 194147869.46875, CrossEntropy: 0.03522541746497154, Accuracy: 0.9879795396419437\n",
      "Iter 62 / 2000, Loss: 186063282.6875, CrossEntropy: 0.03197618946433067, Accuracy: 0.988531010230179\n",
      "Iter 63 / 2000, Loss: 186727042.3125, CrossEntropy: 0.032274212688207626, Accuracy: 0.988387148337596\n",
      "Iter 64 / 2000, Loss: 190555983.5625, CrossEntropy: 0.033769942820072174, Accuracy: 0.9887468030690537\n",
      "Iter 65 / 2000, Loss: 182750847.34375, CrossEntropy: 0.03065439499914646, Accuracy: 0.9888387148337596\n",
      "Iter 66 / 2000, Loss: 183892485.59375, CrossEntropy: 0.031075630336999893, Accuracy: 0.9890105498721228\n",
      "Iter 67 / 2000, Loss: 183391490.1875, CrossEntropy: 0.0308770090341568, Accuracy: 0.9893981777493607\n",
      "Iter 68 / 2000, Loss: 180516620.34375, CrossEntropy: 0.02980487048625946, Accuracy: 0.9897738171355498\n",
      "Iter 69 / 2000, Loss: 178651518.875, CrossEntropy: 0.028984438627958298, Accuracy: 0.9907368925831203\n",
      "Iter 70 / 2000, Loss: 180353660.875, CrossEntropy: 0.029636643826961517, Accuracy: 0.9896379475703325\n",
      "Iter 71 / 2000, Loss: 179384827.4375, CrossEntropy: 0.02923939749598503, Accuracy: 0.9898297634271099\n",
      "Iter 72 / 2000, Loss: 179129720.15625, CrossEntropy: 0.029136313125491142, Accuracy: 0.9894581202046037\n",
      "Iter 73 / 2000, Loss: 173565492.375, CrossEntropy: 0.02689475007355213, Accuracy: 0.9902173913043478\n",
      "Iter 74 / 2000, Loss: 176310039.0, CrossEntropy: 0.027992989867925644, Accuracy: 0.9900775255754476\n",
      "Iter 75 / 2000, Loss: 171991526.6875, CrossEntropy: 0.02627274952828884, Accuracy: 0.9912963554987213\n",
      "Iter 76 / 2000, Loss: 174694102.75, CrossEntropy: 0.027347715571522713, Accuracy: 0.99065297314578\n",
      "Iter 77 / 2000, Loss: 171327336.21875, CrossEntropy: 0.0259824488312006, Accuracy: 0.9908248081841433\n",
      "Iter 78 / 2000, Loss: 166906181.28125, CrossEntropy: 0.0241966862231493, Accuracy: 0.9914162404092072\n",
      "Iter 79 / 2000, Loss: 166291391.90625, CrossEntropy: 0.023934587836265564, Accuracy: 0.992127557544757\n",
      "Iter 80 / 2000, Loss: 170670350.15625, CrossEntropy: 0.025677140802145004, Accuracy: 0.9908487851662404\n",
      "Iter 81 / 2000, Loss: 166542739.78125, CrossEntropy: 0.02401430904865265, Accuracy: 0.9919876918158568\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 82 / 2000, Loss: 168586408.75, CrossEntropy: 0.024958312511444092, Accuracy: 0.9919437340153453\n",
      "Iter 83 / 2000, Loss: 169264285.0625, CrossEntropy: 0.025086740031838417, Accuracy: 0.9913882672634271\n",
      "Iter 84 / 2000, Loss: 171243213.1875, CrossEntropy: 0.025881066918373108, Accuracy: 0.9908367966751919\n",
      "Iter 85 / 2000, Loss: 159958355.28125, CrossEntropy: 0.021395662799477577, Accuracy: 0.9928228900255756\n",
      "Iter 86 / 2000, Loss: 168245660.71875, CrossEntropy: 0.024689437821507454, Accuracy: 0.991656010230179\n",
      "Iter 87 / 2000, Loss: 164590853.65625, CrossEntropy: 0.02321067452430725, Accuracy: 0.991775895140665\n",
      "Iter 88 / 2000, Loss: 160675408.875, CrossEntropy: 0.021617017686367035, Accuracy: 0.9930866368286445\n",
      "Iter 89 / 2000, Loss: 159952843.28125, CrossEntropy: 0.021366922184824944, Accuracy: 0.9923753196930947\n",
      "Iter 90 / 2000, Loss: 160407824.5, CrossEntropy: 0.021522538736462593, Accuracy: 0.9921435421994885\n",
      "Iter 91 / 2000, Loss: 160553526.1875, CrossEntropy: 0.021535931155085564, Accuracy: 0.9930067135549873\n",
      "Iter 92 / 2000, Loss: 160019387.46875, CrossEntropy: 0.02131759002804756, Accuracy: 0.9923873081841432\n",
      "Iter 93 / 2000, Loss: 161719941.59375, CrossEntropy: 0.021983832120895386, Accuracy: 0.9921075767263428\n",
      "Iter 94 / 2000, Loss: 158577301.125, CrossEntropy: 0.020737284794449806, Accuracy: 0.9931745524296676\n",
      "Iter 95 / 2000, Loss: 159614659.65625, CrossEntropy: 0.02112733945250511, Accuracy: 0.9927269820971867\n",
      "Iter 96 / 2000, Loss: 161451520.875, CrossEntropy: 0.02185889333486557, Accuracy: 0.9923673273657289\n",
      "Iter 97 / 2000, Loss: 157610520.6875, CrossEntropy: 0.020314356312155724, Accuracy: 0.9928868286445013\n",
      "Iter 98 / 2000, Loss: 158579066.0625, CrossEntropy: 0.020708873867988586, Accuracy: 0.9929947250639386\n",
      "Iter 99 / 2000, Loss: 156122238.25, CrossEntropy: 0.01971273496747017, Accuracy: 0.9933144181585678\n",
      "Iter 100 / 2000, Loss: 159476406.6875, CrossEntropy: 0.021083667874336243, Accuracy: 0.9924112851662403\n",
      "Iter 101 / 2000, Loss: 163082968.96875, CrossEntropy: 0.022465113550424576, Accuracy: 0.9924872122762148\n",
      "Iter 102 / 2000, Loss: 150947828.65625, CrossEntropy: 0.017622344195842743, Accuracy: 0.9939538043478261\n",
      "Iter 103 / 2000, Loss: 155324399.90625, CrossEntropy: 0.019383499398827553, Accuracy: 0.9926550511508951\n",
      "Iter 104 / 2000, Loss: 156139157.84375, CrossEntropy: 0.019810019060969353, Accuracy: 0.993574168797954\n",
      "Iter 105 / 2000, Loss: 154612425.1875, CrossEntropy: 0.019088666886091232, Accuracy: 0.9934502877237851\n",
      "Iter 106 / 2000, Loss: 150112505.34375, CrossEntropy: 0.017330896109342575, Accuracy: 0.9941016624040921\n",
      "Iter 107 / 2000, Loss: 155349765.40625, CrossEntropy: 0.019333994016051292, Accuracy: 0.9928748401534527\n",
      "Iter 108 / 2000, Loss: 159111215.4375, CrossEntropy: 0.020822659134864807, Accuracy: 0.9926670396419437\n",
      "Iter 109 / 2000, Loss: 153309403.9375, CrossEntropy: 0.01849844865500927, Accuracy: 0.9938339194373402\n",
      "Iter 110 / 2000, Loss: 149517362.53125, CrossEntropy: 0.017004815861582756, Accuracy: 0.9943534207161125\n",
      "Iter 111 / 2000, Loss: 154076767.875, CrossEntropy: 0.018779633566737175, Accuracy: 0.993366368286445\n",
      "Iter 112 / 2000, Loss: 153397540.25, CrossEntropy: 0.018509162589907646, Accuracy: 0.9937539961636829\n",
      "Iter 113 / 2000, Loss: 145852585.875, CrossEntropy: 0.01550188846886158, Accuracy: 0.9947130754475704\n",
      "Iter 114 / 2000, Loss: 157003402.15625, CrossEntropy: 0.019927771762013435, Accuracy: 0.993486253196931\n",
      "Iter 115 / 2000, Loss: 149658797.65625, CrossEntropy: 0.016999509185552597, Accuracy: 0.9940537084398977\n",
      "Iter 116 / 2000, Loss: 150741124.875, CrossEntropy: 0.017405247315764427, Accuracy: 0.9937460038363172\n",
      "Iter 117 / 2000, Loss: 147483546.53125, CrossEntropy: 0.016207478940486908, Accuracy: 0.994341432225064\n",
      "Iter 118 / 2000, Loss: 154708115.8125, CrossEntropy: 0.018989894539117813, Accuracy: 0.9934343030690538\n",
      "Iter 119 / 2000, Loss: 147167708.21875, CrossEntropy: 0.015955226495862007, Accuracy: 0.9944653132992327\n",
      "Iter 120 / 2000, Loss: 147247517.59375, CrossEntropy: 0.016013948246836662, Accuracy: 0.9945412404092072\n",
      "Iter 121 / 2000, Loss: 145474461.03125, CrossEntropy: 0.015328054316341877, Accuracy: 0.9945212595907928\n",
      "Iter 122 / 2000, Loss: 146395581.75, CrossEntropy: 0.01566162519156933, Accuracy: 0.9948809143222507\n",
      "Iter 123 / 2000, Loss: 147949625.0, CrossEntropy: 0.016233021393418312, Accuracy: 0.9944653132992327\n",
      "Iter 124 / 2000, Loss: 145599289.53125, CrossEntropy: 0.015285981819033623, Accuracy: 0.9948249680306905\n",
      "Iter 125 / 2000, Loss: 146183896.6875, CrossEntropy: 0.015557708218693733, Accuracy: 0.9949328644501279\n",
      "Iter 126 / 2000, Loss: 148739943.59375, CrossEntropy: 0.016528304666280746, Accuracy: 0.9942255434782609\n",
      "Iter 127 / 2000, Loss: 152841014.46875, CrossEntropy: 0.018243614584207535, Accuracy: 0.9938898657289001\n",
      "Iter 128 / 2000, Loss: 149608281.8125, CrossEntropy: 0.016890184953808784, Accuracy: 0.9945732097186701\n",
      "Iter 129 / 2000, Loss: 143877636.84375, CrossEntropy: 0.014556952752172947, Accuracy: 0.9950647378516624\n",
      "Iter 130 / 2000, Loss: 147533068.1875, CrossEntropy: 0.016019532456994057, Accuracy: 0.9947250639386189\n",
      "Iter 131 / 2000, Loss: 144075897.9375, CrossEntropy: 0.014633812941610813, Accuracy: 0.9952046035805626\n",
      "Iter 132 / 2000, Loss: 148351011.46875, CrossEntropy: 0.016372699290513992, Accuracy: 0.9944413363171356\n",
      "Iter 133 / 2000, Loss: 143642328.6875, CrossEntropy: 0.014431379735469818, Accuracy: 0.9948649296675192\n",
      "Iter 134 / 2000, Loss: 140481572.6875, CrossEntropy: 0.013177091255784035, Accuracy: 0.9954124040920717\n",
      "Iter 135 / 2000, Loss: 143610008.8125, CrossEntropy: 0.0144016919657588, Accuracy: 0.9949848145780051\n",
      "Iter 136 / 2000, Loss: 145250623.5, CrossEntropy: 0.015056449919939041, Accuracy: 0.9949248721227621\n",
      "Iter 137 / 2000, Loss: 145515297.46875, CrossEntropy: 0.015149212442338467, Accuracy: 0.9945851982097187\n",
      "Iter 138 / 2000, Loss: 141334525.375, CrossEntropy: 0.013483921997249126, Accuracy: 0.9951526534526854\n",
      "Iter 139 / 2000, Loss: 146089546.0625, CrossEntropy: 0.015360679477453232, Accuracy: 0.9948649296675192\n",
      "Iter 140 / 2000, Loss: 139092694.96875, CrossEntropy: 0.01255567092448473, Accuracy: 0.9957041240409207\n",
      "Iter 141 / 2000, Loss: 139719072.40625, CrossEntropy: 0.012801817618310452, Accuracy: 0.9955642583120204\n",
      "Iter 142 / 2000, Loss: 147746877.90625, CrossEntropy: 0.01600438728928566, Accuracy: 0.9947450447570333\n",
      "Iter 143 / 2000, Loss: 144342957.78125, CrossEntropy: 0.014631330966949463, Accuracy: 0.9949648337595908\n",
      "Iter 144 / 2000, Loss: 143555318.46875, CrossEntropy: 0.014308251440525055, Accuracy: 0.9953244884910486\n",
      "Iter 145 / 2000, Loss: 146936490.46875, CrossEntropy: 0.015653390437364578, Accuracy: 0.994605179028133\n",
      "Iter 146 / 2000, Loss: 139840202.65625, CrossEntropy: 0.012836961075663567, Accuracy: 0.9956521739130435\n",
      "Iter 147 / 2000, Loss: 143966467.9375, CrossEntropy: 0.014448009431362152, Accuracy: 0.995104699488491\n",
      "Iter 148 / 2000, Loss: 138856612.4375, CrossEntropy: 0.012413926422595978, Accuracy: 0.9956521739130435\n",
      "Iter 149 / 2000, Loss: 139272193.84375, CrossEntropy: 0.01256842352449894, Accuracy: 0.9957320971867007\n",
      "Iter 150 / 2000, Loss: 145130167.5625, CrossEntropy: 0.014891012571752071, Accuracy: 0.9947050831202046\n",
      "Iter 151 / 2000, Loss: 142960253.4375, CrossEntropy: 0.01402549259364605, Accuracy: 0.9952725383631714\n",
      "Iter 152 / 2000, Loss: 140348178.6875, CrossEntropy: 0.01300724782049656, Accuracy: 0.9955123081841433\n",
      "Iter 153 / 2000, Loss: 137226566.21875, CrossEntropy: 0.011719128116965294, Accuracy: 0.9962116368286446\n",
      "Iter 154 / 2000, Loss: 145496924.875, CrossEntropy: 0.015052126720547676, Accuracy: 0.9951806265984655\n",
      "Iter 155 / 2000, Loss: 136494425.0, CrossEntropy: 0.011403977870941162, Accuracy: 0.9960437979539642\n",
      "Iter 156 / 2000, Loss: 137060415.96875, CrossEntropy: 0.011621015146374702, Accuracy: 0.9958639705882353\n",
      "Iter 157 / 2000, Loss: 140625861.15625, CrossEntropy: 0.0130338529124856, Accuracy: 0.9955242966751918\n",
      "Iter 158 / 2000, Loss: 138792811.125, CrossEntropy: 0.01229212898761034, Accuracy: 0.9956641624040921\n",
      "Iter 159 / 2000, Loss: 141450090.6875, CrossEntropy: 0.013403935357928276, Accuracy: 0.9954923273657289\n",
      "Iter 160 / 2000, Loss: 139423381.0, CrossEntropy: 0.0125492038205266, Accuracy: 0.9956321930946291\n",
      "Iter 161 / 2000, Loss: 140896405.15625, CrossEntropy: 0.013127611018717289, Accuracy: 0.9953524616368287\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 162 / 2000, Loss: 136427410.90625, CrossEntropy: 0.011319207958877087, Accuracy: 0.9962835677749361\n",
      "Iter 163 / 2000, Loss: 137939950.4375, CrossEntropy: 0.011977321468293667, Accuracy: 0.9959518861892583\n",
      "Iter 164 / 2000, Loss: 140655295.53125, CrossEntropy: 0.013050086796283722, Accuracy: 0.9955402813299233\n",
      "Iter 165 / 2000, Loss: 140412581.5, CrossEntropy: 0.012890801765024662, Accuracy: 0.9955442774936062\n",
      "Iter 166 / 2000, Loss: 134569155.625, CrossEntropy: 0.010541852563619614, Accuracy: 0.9965233375959079\n",
      "Iter 167 / 2000, Loss: 141762778.125, CrossEntropy: 0.013410837389528751, Accuracy: 0.9953244884910486\n",
      "Iter 168 / 2000, Loss: 137768569.90625, CrossEntropy: 0.011805448681116104, Accuracy: 0.9959638746803069\n",
      "Iter 169 / 2000, Loss: 140961400.53125, CrossEntropy: 0.013084802776575089, Accuracy: 0.9955242966751918\n",
      "Iter 170 / 2000, Loss: 137009133.375, CrossEntropy: 0.01148665975779295, Accuracy: 0.9960837595907929\n",
      "Iter 171 / 2000, Loss: 140005688.3125, CrossEntropy: 0.012676732614636421, Accuracy: 0.9958639705882353\n",
      "Iter 172 / 2000, Loss: 138703027.8125, CrossEntropy: 0.01215207390487194, Accuracy: 0.9957440856777494\n",
      "Iter 173 / 2000, Loss: 139263664.0625, CrossEntropy: 0.012461469508707523, Accuracy: 0.9959598785166242\n",
      "Iter 174 / 2000, Loss: 135133365.1875, CrossEntropy: 0.010710545815527439, Accuracy: 0.9963235294117647\n",
      "Iter 175 / 2000, Loss: 139561089.125, CrossEntropy: 0.012510831467807293, Accuracy: 0.996019820971867\n",
      "Iter 176 / 2000, Loss: 134447836.53125, CrossEntropy: 0.01041528582572937, Accuracy: 0.9965233375959079\n",
      "Iter 177 / 2000, Loss: 141509426.8125, CrossEntropy: 0.013230174779891968, Accuracy: 0.9951446611253197\n",
      "Iter 178 / 2000, Loss: 137642929.53125, CrossEntropy: 0.011709393933415413, Accuracy: 0.9963914641943734\n",
      "Iter 179 / 2000, Loss: 135566850.5625, CrossEntropy: 0.01084177941083908, Accuracy: 0.9962436061381074\n",
      "Iter 180 / 2000, Loss: 138510292.28125, CrossEntropy: 0.012018892914056778, Accuracy: 0.9961237212276215\n",
      "Iter 181 / 2000, Loss: 138558006.34375, CrossEntropy: 0.012022176757454872, Accuracy: 0.995784047314578\n",
      "Iter 182 / 2000, Loss: 135452904.90625, CrossEntropy: 0.01077079027891159, Accuracy: 0.9964434143222506\n",
      "Iter 183 / 2000, Loss: 138968669.0, CrossEntropy: 0.012168423272669315, Accuracy: 0.995724104859335\n",
      "Iter 184 / 2000, Loss: 136917760.40625, CrossEntropy: 0.011340431869029999, Accuracy: 0.9959838554987213\n",
      "Iter 185 / 2000, Loss: 132931798.71875, CrossEntropy: 0.009765677154064178, Accuracy: 0.996551310741688\n",
      "Iter 186 / 2000, Loss: 135962105.5, CrossEntropy: 0.010942833498120308, Accuracy: 0.9962436061381074\n",
      "Iter 187 / 2000, Loss: 136421733.4375, CrossEntropy: 0.011118809692561626, Accuracy: 0.9960837595907929\n",
      "Iter 188 / 2000, Loss: 136604842.96875, CrossEntropy: 0.011186009272933006, Accuracy: 0.996343510230179\n",
      "Iter 189 / 2000, Loss: 134917019.84375, CrossEntropy: 0.010580887086689472, Accuracy: 0.9964394181585678\n",
      "Iter 190 / 2000, Loss: 133529106.75, CrossEntropy: 0.009942086413502693, Accuracy: 0.996403452685422\n",
      "Iter 191 / 2000, Loss: 135677541.5625, CrossEntropy: 0.010806496255099773, Accuracy: 0.9962715792838875\n",
      "Iter 192 / 2000, Loss: 132655696.75, CrossEntropy: 0.009625907056033611, Accuracy: 0.9969709079283888\n",
      "Iter 193 / 2000, Loss: 135385290.90625, CrossEntropy: 0.010666667483747005, Accuracy: 0.9962116368286446\n",
      "Iter 194 / 2000, Loss: 136411496.34375, CrossEntropy: 0.011116821318864822, Accuracy: 0.9963914641943734\n",
      "Iter 195 / 2000, Loss: 136837031.9375, CrossEntropy: 0.011337701231241226, Accuracy: 0.996199648337596\n",
      "Iter 196 / 2000, Loss: 136787696.0625, CrossEntropy: 0.011195863597095013, Accuracy: 0.9962436061381074\n",
      "Iter 197 / 2000, Loss: 132099566.0625, CrossEntropy: 0.009330390952527523, Accuracy: 0.9965113491048594\n",
      "Iter 198 / 2000, Loss: 130390922.65625, CrossEntropy: 0.008624162524938583, Accuracy: 0.996843030690537\n",
      "Iter 199 / 2000, Loss: 134754307.875, CrossEntropy: 0.01036027166992426, Accuracy: 0.9965233375959079\n",
      "Iter 200 / 2000, Loss: 133888747.875, CrossEntropy: 0.010048318654298782, Accuracy: 0.9965712915601024\n",
      "Iter 201 / 2000, Loss: 134886065.9375, CrossEntropy: 0.010435320436954498, Accuracy: 0.9961117327365729\n",
      "Iter 202 / 2000, Loss: 133415878.78125, CrossEntropy: 0.009804406203329563, Accuracy: 0.9967031649616368\n",
      "Iter 203 / 2000, Loss: 133866242.1875, CrossEntropy: 0.009980795904994011, Accuracy: 0.9965233375959079\n",
      "Iter 204 / 2000, Loss: 135331654.125, CrossEntropy: 0.01056216936558485, Accuracy: 0.9967631074168798\n",
      "Iter 205 / 2000, Loss: 135208633.9375, CrossEntropy: 0.010500328615307808, Accuracy: 0.9962835677749361\n",
      "Iter 206 / 2000, Loss: 134192475.40625, CrossEntropy: 0.010087798349559307, Accuracy: 0.9964434143222506\n",
      "Iter 207 / 2000, Loss: 132983755.1875, CrossEntropy: 0.00960585568100214, Accuracy: 0.9966232416879796\n",
      "Iter 208 / 2000, Loss: 131786416.09375, CrossEntropy: 0.009108497761189938, Accuracy: 0.9968630115089514\n",
      "Iter 209 / 2000, Loss: 130791484.46875, CrossEntropy: 0.008701789192855358, Accuracy: 0.9969429347826086\n",
      "Iter 210 / 2000, Loss: 138638029.875, CrossEntropy: 0.011828210204839706, Accuracy: 0.9957640664961637\n",
      "Iter 211 / 2000, Loss: 135979437.875, CrossEntropy: 0.010764792561531067, Accuracy: 0.9961636828644501\n",
      "Iter 212 / 2000, Loss: 130226364.21875, CrossEntropy: 0.00845804437994957, Accuracy: 0.9973225703324808\n",
      "Iter 213 / 2000, Loss: 127653461.75, CrossEntropy: 0.007454404607415199, Accuracy: 0.9974704283887468\n",
      "Iter 214 / 2000, Loss: 137571923.75, CrossEntropy: 0.011386062949895859, Accuracy: 0.9963714833759592\n",
      "Iter 215 / 2000, Loss: 132722795.0625, CrossEntropy: 0.00944314245134592, Accuracy: 0.9965912723785166\n",
      "Iter 216 / 2000, Loss: 130926767.09375, CrossEntropy: 0.008710058405995369, Accuracy: 0.9971427429667519\n",
      "Iter 217 / 2000, Loss: 133625584.53125, CrossEntropy: 0.009802055545151234, Accuracy: 0.9963994565217392\n",
      "Iter 218 / 2000, Loss: 132459333.09375, CrossEntropy: 0.009332072921097279, Accuracy: 0.9966911764705882\n",
      "Iter 219 / 2000, Loss: 131151657.0, CrossEntropy: 0.008782202377915382, Accuracy: 0.9968110613810742\n",
      "Iter 220 / 2000, Loss: 131241904.34375, CrossEntropy: 0.008800450712442398, Accuracy: 0.9973225703324808\n",
      "Iter 221 / 2000, Loss: 133439356.0625, CrossEntropy: 0.00966943334788084, Accuracy: 0.9969429347826086\n",
      "Iter 222 / 2000, Loss: 131283663.46875, CrossEntropy: 0.00880212988704443, Accuracy: 0.996962915601023\n",
      "Iter 223 / 2000, Loss: 133087103.625, CrossEntropy: 0.009515495039522648, Accuracy: 0.9969429347826086\n",
      "Iter 224 / 2000, Loss: 131704965.3125, CrossEntropy: 0.0089688366279006, Accuracy: 0.9967111572890026\n",
      "Iter 225 / 2000, Loss: 131643607.90625, CrossEntropy: 0.008982562460005283, Accuracy: 0.9969789002557545\n",
      "Iter 226 / 2000, Loss: 135785942.65625, CrossEntropy: 0.01057136058807373, Accuracy: 0.9965233375959079\n",
      "Iter 227 / 2000, Loss: 131037260.75, CrossEntropy: 0.008665253408253193, Accuracy: 0.9968230498721228\n",
      "Iter 228 / 2000, Loss: 134016346.1875, CrossEntropy: 0.009859468787908554, Accuracy: 0.9965233375959079\n",
      "Iter 229 / 2000, Loss: 132730311.5, CrossEntropy: 0.009331636130809784, Accuracy: 0.9967031649616368\n",
      "Iter 230 / 2000, Loss: 132961035.28125, CrossEntropy: 0.009438656270503998, Accuracy: 0.9969509271099745\n",
      "Iter 231 / 2000, Loss: 131556937.71875, CrossEntropy: 0.008853199891746044, Accuracy: 0.997170716112532\n",
      "Iter 232 / 2000, Loss: 132894284.25, CrossEntropy: 0.009392790496349335, Accuracy: 0.9966512148337596\n",
      "Iter 233 / 2000, Loss: 131902958.4375, CrossEntropy: 0.008992294780910015, Accuracy: 0.9969309462915601\n",
      "Iter 234 / 2000, Loss: 133155542.0625, CrossEntropy: 0.009461071342229843, Accuracy: 0.9966831841432225\n",
      "Iter 235 / 2000, Loss: 131282165.84375, CrossEntropy: 0.008715412579476833, Accuracy: 0.9968230498721228\n",
      "Iter 236 / 2000, Loss: 131784809.65625, CrossEntropy: 0.008897255174815655, Accuracy: 0.9967830882352942\n",
      "Iter 237 / 2000, Loss: 130975096.625, CrossEntropy: 0.008566506206989288, Accuracy: 0.9971427429667519\n",
      "Iter 238 / 2000, Loss: 129525397.28125, CrossEntropy: 0.007982020266354084, Accuracy: 0.9973425511508951\n",
      "Iter 239 / 2000, Loss: 135865117.21875, CrossEntropy: 0.010505572892725468, Accuracy: 0.9966632033248082\n",
      "Iter 240 / 2000, Loss: 130224707.4375, CrossEntropy: 0.008307392708957195, Accuracy: 0.9974904092071611\n",
      "Iter 241 / 2000, Loss: 137383267.1875, CrossEntropy: 0.011145581491291523, Accuracy: 0.9963315217391304\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 242 / 2000, Loss: 128604048.59375, CrossEntropy: 0.007582635153084993, Accuracy: 0.9971627237851662\n",
      "Iter 243 / 2000, Loss: 132898398.3125, CrossEntropy: 0.009293334558606148, Accuracy: 0.9968030690537084\n",
      "Iter 244 / 2000, Loss: 129726371.34375, CrossEntropy: 0.008019079454243183, Accuracy: 0.9973025895140665\n",
      "Iter 245 / 2000, Loss: 130193844.53125, CrossEntropy: 0.008195937611162663, Accuracy: 0.9974224744245525\n",
      "Iter 246 / 2000, Loss: 131542584.5625, CrossEntropy: 0.008727283217012882, Accuracy: 0.9970428388746803\n",
      "Iter 247 / 2000, Loss: 127440875.0, CrossEntropy: 0.0070918952114880085, Accuracy: 0.9974104859335039\n",
      "Iter 248 / 2000, Loss: 130284508.0, CrossEntropy: 0.008215007372200489, Accuracy: 0.9971027813299232\n",
      "Iter 249 / 2000, Loss: 129975403.8125, CrossEntropy: 0.008079557679593563, Accuracy: 0.9970028772378516\n",
      "Iter 250 / 2000, Loss: 136310063.4375, CrossEntropy: 0.010603840462863445, Accuracy: 0.9965233375959079\n",
      "Iter 251 / 2000, Loss: 132251594.21875, CrossEntropy: 0.008975251577794552, Accuracy: 0.9970028772378516\n",
      "Iter 252 / 2000, Loss: 129813003.03125, CrossEntropy: 0.008031606674194336, Accuracy: 0.9976502557544757\n",
      "Iter 253 / 2000, Loss: 132310985.90625, CrossEntropy: 0.008983864448964596, Accuracy: 0.9971427429667519\n",
      "Iter 254 / 2000, Loss: 132272605.15625, CrossEntropy: 0.008966203778982162, Accuracy: 0.997022858056266\n",
      "Iter 255 / 2000, Loss: 127669088.5, CrossEntropy: 0.007114870008081198, Accuracy: 0.9975423593350383\n",
      "Iter 256 / 2000, Loss: 129721472.71875, CrossEntropy: 0.007956111803650856, Accuracy: 0.9971507352941177\n",
      "Iter 257 / 2000, Loss: 132670391.125, CrossEntropy: 0.009098924696445465, Accuracy: 0.99690297314578\n",
      "Iter 258 / 2000, Loss: 127933835.0, CrossEntropy: 0.007207789923995733, Accuracy: 0.9975423593350383\n",
      "Iter 259 / 2000, Loss: 133255538.96875, CrossEntropy: 0.009319235570728779, Accuracy: 0.9971227621483376\n",
      "Iter 260 / 2000, Loss: 131448438.46875, CrossEntropy: 0.008592171594500542, Accuracy: 0.9973025895140665\n",
      "Iter 261 / 2000, Loss: 126060786.15625, CrossEntropy: 0.006442260928452015, Accuracy: 0.9977101982097187\n",
      "Iter 262 / 2000, Loss: 131554655.5, CrossEntropy: 0.008630050346255302, Accuracy: 0.9970708120204604\n",
      "Iter 263 / 2000, Loss: 130426224.65625, CrossEntropy: 0.008159033954143524, Accuracy: 0.9974624360613811\n",
      "Iter 264 / 2000, Loss: 127876820.84375, CrossEntropy: 0.007133886683732271, Accuracy: 0.9975223785166241\n",
      "Iter 265 / 2000, Loss: 131339389.46875, CrossEntropy: 0.008533076383173466, Accuracy: 0.9972306585677749\n",
      "Iter 266 / 2000, Loss: 131775757.0, CrossEntropy: 0.008734269067645073, Accuracy: 0.9970907928388747\n",
      "Iter 267 / 2000, Loss: 131247844.3125, CrossEntropy: 0.008476008661091328, Accuracy: 0.9972706202046037\n",
      "Iter 268 / 2000, Loss: 131330189.03125, CrossEntropy: 0.008489742875099182, Accuracy: 0.9971827046035806\n",
      "Iter 269 / 2000, Loss: 130971675.6875, CrossEntropy: 0.008334112353622913, Accuracy: 0.9970428388746803\n",
      "Iter 270 / 2000, Loss: 131367280.4375, CrossEntropy: 0.008490498177707195, Accuracy: 0.9971027813299232\n",
      "Iter 271 / 2000, Loss: 128156624.96875, CrossEntropy: 0.007199466694146395, Accuracy: 0.9974824168797954\n",
      "Iter 272 / 2000, Loss: 132340753.5, CrossEntropy: 0.008859820663928986, Accuracy: 0.9973425511508951\n",
      "Iter 273 / 2000, Loss: 129378820.46875, CrossEntropy: 0.007738418877124786, Accuracy: 0.997346547314578\n",
      "Iter 274 / 2000, Loss: 130916508.1875, CrossEntropy: 0.00827798806130886, Accuracy: 0.9969429347826086\n",
      "Iter 275 / 2000, Loss: 129069279.25, CrossEntropy: 0.007530746050179005, Accuracy: 0.9973825127877238\n",
      "Iter 276 / 2000, Loss: 129654696.71875, CrossEntropy: 0.007758865598589182, Accuracy: 0.9973625319693095\n",
      "Iter 277 / 2000, Loss: 130881203.46875, CrossEntropy: 0.008243193849921227, Accuracy: 0.9974224744245525\n",
      "Iter 278 / 2000, Loss: 127130847.59375, CrossEntropy: 0.006744798738509417, Accuracy: 0.99767023657289\n",
      "Iter 279 / 2000, Loss: 130773597.9375, CrossEntropy: 0.008194190450012684, Accuracy: 0.9974304667519182\n",
      "Iter 280 / 2000, Loss: 129369392.84375, CrossEntropy: 0.007617687340825796, Accuracy: 0.997582320971867\n",
      "Iter 281 / 2000, Loss: 129493041.84375, CrossEntropy: 0.0076630376279354095, Accuracy: 0.9973225703324808\n",
      "Iter 282 / 2000, Loss: 125828671.84375, CrossEntropy: 0.00626468425616622, Accuracy: 0.9976982097186702\n",
      "Iter 283 / 2000, Loss: 129181446.875, CrossEntropy: 0.007523904088884592, Accuracy: 0.9974424552429667\n",
      "Iter 284 / 2000, Loss: 128241998.78125, CrossEntropy: 0.0071551683358848095, Accuracy: 0.9974904092071611\n",
      "Iter 285 / 2000, Loss: 128179572.25, CrossEntropy: 0.007115135435014963, Accuracy: 0.9974424552429667\n",
      "Iter 286 / 2000, Loss: 130370430.75, CrossEntropy: 0.007976065389811993, Accuracy: 0.9972626278772379\n",
      "Iter 287 / 2000, Loss: 127033877.21875, CrossEntropy: 0.006634848657995462, Accuracy: 0.9977621483375959\n",
      "Iter 288 / 2000, Loss: 133598774.65625, CrossEntropy: 0.009250480681657791, Accuracy: 0.9968030690537084\n",
      "Iter 289 / 2000, Loss: 129910377.1875, CrossEntropy: 0.00777764618396759, Accuracy: 0.9974304667519182\n",
      "Iter 290 / 2000, Loss: 127254485.5625, CrossEntropy: 0.0067012375220656395, Accuracy: 0.9976422634271099\n",
      "Iter 291 / 2000, Loss: 129127665.71875, CrossEntropy: 0.007441869005560875, Accuracy: 0.9972826086956522\n",
      "Iter 292 / 2000, Loss: 129077636.46875, CrossEntropy: 0.0074409376829862595, Accuracy: 0.9975103900255755\n",
      "Iter 293 / 2000, Loss: 128462955.28125, CrossEntropy: 0.007163264788687229, Accuracy: 0.9972826086956522\n",
      "Iter 294 / 2000, Loss: 128194080.8125, CrossEntropy: 0.0070818825624883175, Accuracy: 0.9974504475703325\n",
      "Iter 295 / 2000, Loss: 128738287.4375, CrossEntropy: 0.007260901387780905, Accuracy: 0.9975223785166241\n",
      "Iter 296 / 2000, Loss: 128452136.6875, CrossEntropy: 0.007137516047805548, Accuracy: 0.9976822250639387\n",
      "Iter 297 / 2000, Loss: 129710416.875, CrossEntropy: 0.00763341598212719, Accuracy: 0.9974824168797954\n",
      "Iter 298 / 2000, Loss: 130554796.625, CrossEntropy: 0.00796777568757534, Accuracy: 0.9973025895140665\n",
      "Iter 299 / 2000, Loss: 128969325.46875, CrossEntropy: 0.007323357742279768, Accuracy: 0.9973425511508951\n",
      "Iter 300 / 2000, Loss: 130724935.625, CrossEntropy: 0.008023886941373348, Accuracy: 0.9973425511508951\n",
      "Iter 301 / 2000, Loss: 129497493.90625, CrossEntropy: 0.007572066504508257, Accuracy: 0.9974704283887468\n",
      "Iter 302 / 2000, Loss: 126885115.34375, CrossEntropy: 0.006470297928899527, Accuracy: 0.9979419757033248\n",
      "Iter 303 / 2000, Loss: 131799356.21875, CrossEntropy: 0.008426998741924763, Accuracy: 0.9972426470588235\n",
      "Iter 304 / 2000, Loss: 128163706.71875, CrossEntropy: 0.006967597641050816, Accuracy: 0.9977621483375959\n",
      "Iter 305 / 2000, Loss: 129870853.125, CrossEntropy: 0.0076435524970293045, Accuracy: 0.9972426470588235\n",
      "Iter 306 / 2000, Loss: 129108268.34375, CrossEntropy: 0.007331651169806719, Accuracy: 0.9973625319693095\n",
      "Iter 307 / 2000, Loss: 128354209.1875, CrossEntropy: 0.007030772976577282, Accuracy: 0.9974424552429667\n",
      "Iter 308 / 2000, Loss: 125383433.53125, CrossEntropy: 0.005830653943121433, Accuracy: 0.9982217071611253\n",
      "Iter 309 / 2000, Loss: 133815760.0, CrossEntropy: 0.009191598743200302, Accuracy: 0.9966831841432225\n",
      "Iter 310 / 2000, Loss: 123388628.5, CrossEntropy: 0.00501934252679348, Accuracy: 0.998141783887468\n",
      "Iter 311 / 2000, Loss: 132529295.15625, CrossEntropy: 0.008664455264806747, Accuracy: 0.9970428388746803\n",
      "Iter 312 / 2000, Loss: 127752516.8125, CrossEntropy: 0.006769981700927019, Accuracy: 0.9976302749360614\n",
      "Iter 313 / 2000, Loss: 127086129.25, CrossEntropy: 0.006474795285612345, Accuracy: 0.9978620524296675\n",
      "Iter 314 / 2000, Loss: 129653719.5625, CrossEntropy: 0.007493560668081045, Accuracy: 0.9971627237851662\n",
      "Iter 315 / 2000, Loss: 126884878.15625, CrossEntropy: 0.006381618790328503, Accuracy: 0.9978620524296675\n",
      "Iter 316 / 2000, Loss: 129997779.8125, CrossEntropy: 0.007629916537553072, Accuracy: 0.9971507352941177\n",
      "Iter 317 / 2000, Loss: 129968735.59375, CrossEntropy: 0.007600716780871153, Accuracy: 0.9974424552429667\n",
      "Iter 318 / 2000, Loss: 127163569.3125, CrossEntropy: 0.006471011321991682, Accuracy: 0.9978021099744245\n",
      "Iter 319 / 2000, Loss: 126003277.5, CrossEntropy: 0.00601185392588377, Accuracy: 0.997910006393862\n",
      "Iter 320 / 2000, Loss: 131448725.5625, CrossEntropy: 0.008170738816261292, Accuracy: 0.9974024936061381\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 321 / 2000, Loss: 130859633.96875, CrossEntropy: 0.007930709980428219, Accuracy: 0.9973625319693095\n",
      "Iter 322 / 2000, Loss: 128231470.53125, CrossEntropy: 0.006896582897752523, Accuracy: 0.9979899296675192\n",
      "Iter 323 / 2000, Loss: 132242568.8125, CrossEntropy: 0.008471230044960976, Accuracy: 0.9970428388746803\n",
      "Iter 324 / 2000, Loss: 127625662.125, CrossEntropy: 0.006617938168346882, Accuracy: 0.9977821291560103\n",
      "Iter 325 / 2000, Loss: 126438353.25, CrossEntropy: 0.006162535399198532, Accuracy: 0.997850063938619\n",
      "Iter 326 / 2000, Loss: 129108961.65625, CrossEntropy: 0.007197102066129446, Accuracy: 0.9975623401534527\n",
      "Iter 327 / 2000, Loss: 129440657.96875, CrossEntropy: 0.0073557449504733086, Accuracy: 0.9974704283887468\n",
      "Iter 328 / 2000, Loss: 128284576.0625, CrossEntropy: 0.006851927377283573, Accuracy: 0.9976023017902813\n",
      "Iter 329 / 2000, Loss: 129962617.5, CrossEntropy: 0.007515594828873873, Accuracy: 0.9973625319693095\n",
      "Iter 330 / 2000, Loss: 127776209.625, CrossEntropy: 0.006680395919829607, Accuracy: 0.9976902173913044\n",
      "Iter 331 / 2000, Loss: 127163091.15625, CrossEntropy: 0.006383445113897324, Accuracy: 0.9980418797953964\n",
      "Iter 332 / 2000, Loss: 127493023.75, CrossEntropy: 0.006509591359645128, Accuracy: 0.9977621483375959\n",
      "Iter 333 / 2000, Loss: 127772164.25, CrossEntropy: 0.006620257161557674, Accuracy: 0.9977821291560103\n",
      "Iter 334 / 2000, Loss: 127661318.0, CrossEntropy: 0.006568091455847025, Accuracy: 0.9977221867007673\n",
      "Iter 335 / 2000, Loss: 127211512.75, CrossEntropy: 0.006376117002218962, Accuracy: 0.9977421675191815\n",
      "Iter 336 / 2000, Loss: 125013402.625, CrossEntropy: 0.005491812247782946, Accuracy: 0.9980418797953964\n",
      "Iter 337 / 2000, Loss: 129295525.875, CrossEntropy: 0.00719731580466032, Accuracy: 0.997582320971867\n",
      "Iter 338 / 2000, Loss: 126698713.4375, CrossEntropy: 0.006158681586384773, Accuracy: 0.9979419757033248\n",
      "Iter 339 / 2000, Loss: 125316861.53125, CrossEntropy: 0.005592662841081619, Accuracy: 0.9980618606138107\n",
      "Iter 340 / 2000, Loss: 130608655.625, CrossEntropy: 0.007700114976614714, Accuracy: 0.9974424552429667\n",
      "Iter 341 / 2000, Loss: 131133952.5, CrossEntropy: 0.007903864607214928, Accuracy: 0.9974424552429667\n",
      "Iter 342 / 2000, Loss: 129787993.34375, CrossEntropy: 0.007395951077342033, Accuracy: 0.9974904092071611\n",
      "Iter 343 / 2000, Loss: 125193387.46875, CrossEntropy: 0.005543277133256197, Accuracy: 0.9979699488491048\n",
      "Iter 344 / 2000, Loss: 128637652.0, CrossEntropy: 0.006886990275233984, Accuracy: 0.9977621483375959\n",
      "Iter 345 / 2000, Loss: 124987038.84375, CrossEntropy: 0.005428762175142765, Accuracy: 0.9979619565217391\n",
      "Iter 346 / 2000, Loss: 125177814.3125, CrossEntropy: 0.005491477902978659, Accuracy: 0.9979819373401535\n",
      "Iter 347 / 2000, Loss: 129812733.28125, CrossEntropy: 0.007337093353271484, Accuracy: 0.9978220907928389\n",
      "Iter 348 / 2000, Loss: 127372135.25, CrossEntropy: 0.006357555277645588, Accuracy: 0.9979819373401535\n",
      "Iter 349 / 2000, Loss: 127562830.3125, CrossEntropy: 0.006444063503295183, Accuracy: 0.9979699488491048\n",
      "Iter 350 / 2000, Loss: 123693773.03125, CrossEntropy: 0.0048728929832577705, Accuracy: 0.998261668797954\n",
      "Iter 351 / 2000, Loss: 128041334.96875, CrossEntropy: 0.0066537694074213505, Accuracy: 0.9980698529411764\n",
      "Iter 352 / 2000, Loss: 122968352.40625, CrossEntropy: 0.004573511891067028, Accuracy: 0.9984614769820972\n",
      "Iter 353 / 2000, Loss: 129970066.46875, CrossEntropy: 0.007361201569437981, Accuracy: 0.9977621483375959\n",
      "Iter 354 / 2000, Loss: 126444031.59375, CrossEntropy: 0.005947357043623924, Accuracy: 0.9981218030690537\n",
      "Iter 355 / 2000, Loss: 127727790.0625, CrossEntropy: 0.0064578489400446415, Accuracy: 0.9977621483375959\n",
      "Iter 356 / 2000, Loss: 126603523.59375, CrossEntropy: 0.005996836349368095, Accuracy: 0.9980218989769821\n",
      "Iter 357 / 2000, Loss: 128402994.90625, CrossEntropy: 0.006709742825478315, Accuracy: 0.9979020140664961\n",
      "Iter 358 / 2000, Loss: 125632306.65625, CrossEntropy: 0.005595613270998001, Accuracy: 0.9980019181585678\n",
      "Iter 359 / 2000, Loss: 127870997.15625, CrossEntropy: 0.00648548174649477, Accuracy: 0.9978420716112532\n",
      "Iter 360 / 2000, Loss: 132037601.96875, CrossEntropy: 0.008142208680510521, Accuracy: 0.9975623401534527\n",
      "Iter 361 / 2000, Loss: 125708020.90625, CrossEntropy: 0.005608820356428623, Accuracy: 0.9980218989769821\n",
      "Iter 362 / 2000, Loss: 127100687.375, CrossEntropy: 0.00616062618792057, Accuracy: 0.9976422634271099\n",
      "Iter 363 / 2000, Loss: 128849849.5, CrossEntropy: 0.0068589914590120316, Accuracy: 0.9973905051150895\n",
      "Iter 364 / 2000, Loss: 125527339.5625, CrossEntropy: 0.0055216639302670956, Accuracy: 0.998201726342711\n",
      "Iter 365 / 2000, Loss: 125933710.0, CrossEntropy: 0.0056719351559877396, Accuracy: 0.9979819373401535\n",
      "Iter 366 / 2000, Loss: 130911219.9375, CrossEntropy: 0.00765463151037693, Accuracy: 0.9973625319693095\n",
      "Iter 367 / 2000, Loss: 126137890.21875, CrossEntropy: 0.005738462787121534, Accuracy: 0.9980418797953964\n",
      "Iter 368 / 2000, Loss: 126418764.6875, CrossEntropy: 0.005846786312758923, Accuracy: 0.9979020140664961\n",
      "Iter 369 / 2000, Loss: 127850072.875, CrossEntropy: 0.00640924321487546, Accuracy: 0.9979419757033248\n",
      "Iter 370 / 2000, Loss: 127091215.46875, CrossEntropy: 0.006100666709244251, Accuracy: 0.9980618606138107\n",
      "Iter 371 / 2000, Loss: 124490975.75, CrossEntropy: 0.005098753143101931, Accuracy: 0.9982496803069054\n",
      "Iter 372 / 2000, Loss: 131361746.03125, CrossEntropy: 0.007793993689119816, Accuracy: 0.9972226662404092\n",
      "Iter 373 / 2000, Loss: 126333243.8125, CrossEntropy: 0.005779611878097057, Accuracy: 0.9980818414322251\n",
      "Iter 374 / 2000, Loss: 127467086.875, CrossEntropy: 0.006227046251296997, Accuracy: 0.9980019181585678\n",
      "Iter 375 / 2000, Loss: 125101332.125, CrossEntropy: 0.005273923743516207, Accuracy: 0.9983615728900256\n",
      "Iter 376 / 2000, Loss: 127140101.34375, CrossEntropy: 0.006085701286792755, Accuracy: 0.9979020140664961\n",
      "Iter 377 / 2000, Loss: 128273425.375, CrossEntropy: 0.006535853259265423, Accuracy: 0.9977821291560103\n",
      "Iter 378 / 2000, Loss: 128090853.75, CrossEntropy: 0.006456972565501928, Accuracy: 0.9978021099744245\n",
      "Iter 379 / 2000, Loss: 127445568.75, CrossEntropy: 0.006185638252645731, Accuracy: 0.9979619565217391\n",
      "Iter 380 / 2000, Loss: 126830741.65625, CrossEntropy: 0.005936021450906992, Accuracy: 0.9980418797953964\n",
      "Iter 381 / 2000, Loss: 128160715.125, CrossEntropy: 0.006467684172093868, Accuracy: 0.9978101023017903\n",
      "Iter 382 / 2000, Loss: 125463274.28125, CrossEntropy: 0.005375365726649761, Accuracy: 0.9980019181585678\n",
      "Iter 383 / 2000, Loss: 126137737.875, CrossEntropy: 0.0056482465006411076, Accuracy: 0.9983096227621484\n",
      "Iter 384 / 2000, Loss: 124714457.0625, CrossEntropy: 0.005062204319983721, Accuracy: 0.998201726342711\n",
      "Iter 385 / 2000, Loss: 129090018.03125, CrossEntropy: 0.0068040345795452595, Accuracy: 0.9979419757033248\n",
      "Iter 386 / 2000, Loss: 125503598.90625, CrossEntropy: 0.005366983357816935, Accuracy: 0.998261668797954\n",
      "Iter 387 / 2000, Loss: 123546665.46875, CrossEntropy: 0.004578000865876675, Accuracy: 0.9985414002557544\n",
      "Iter 388 / 2000, Loss: 127023905.28125, CrossEntropy: 0.005975763313472271, Accuracy: 0.9978101023017903\n",
      "Iter 389 / 2000, Loss: 127786056.0, CrossEntropy: 0.00625821016728878, Accuracy: 0.9978420716112532\n",
      "Iter 390 / 2000, Loss: 127955768.8125, CrossEntropy: 0.0063196211121976376, Accuracy: 0.9980818414322251\n",
      "Iter 391 / 2000, Loss: 125629148.90625, CrossEntropy: 0.0053840880282223225, Accuracy: 0.998261668797954\n",
      "Iter 392 / 2000, Loss: 125377012.09375, CrossEntropy: 0.005276893265545368, Accuracy: 0.998201726342711\n",
      "Iter 393 / 2000, Loss: 128646667.0, CrossEntropy: 0.006578095257282257, Accuracy: 0.9980019181585678\n",
      "Iter 394 / 2000, Loss: 127881898.21875, CrossEntropy: 0.006268588360399008, Accuracy: 0.9979219948849105\n",
      "Iter 395 / 2000, Loss: 125509991.625, CrossEntropy: 0.005311685614287853, Accuracy: 0.9982816496163683\n",
      "Iter 396 / 2000, Loss: 126074204.5625, CrossEntropy: 0.005558012053370476, Accuracy: 0.9981497762148338\n",
      "Iter 397 / 2000, Loss: 126615266.46875, CrossEntropy: 0.005743423011153936, Accuracy: 0.9981817455242967\n",
      "Iter 398 / 2000, Loss: 128411304.1875, CrossEntropy: 0.006460574921220541, Accuracy: 0.99767023657289\n",
      "Iter 399 / 2000, Loss: 124439231.4375, CrossEntropy: 0.004858546424657106, Accuracy: 0.9983615728900256\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 400 / 2000, Loss: 127781992.65625, CrossEntropy: 0.006194037850946188, Accuracy: 0.9979020140664961\n",
      "Iter 401 / 2000, Loss: 128105938.5625, CrossEntropy: 0.0063124909065663815, Accuracy: 0.9978620524296675\n",
      "Iter 402 / 2000, Loss: 130404213.59375, CrossEntropy: 0.007223429158329964, Accuracy: 0.9978021099744245\n",
      "Iter 403 / 2000, Loss: 126600271.15625, CrossEntropy: 0.00569757167249918, Accuracy: 0.9979020140664961\n",
      "Iter 404 / 2000, Loss: 125301046.0625, CrossEntropy: 0.005174572113901377, Accuracy: 0.9981817455242967\n",
      "Iter 405 / 2000, Loss: 129672860.3125, CrossEntropy: 0.006911832373589277, Accuracy: 0.9978420716112532\n",
      "Iter 406 / 2000, Loss: 127436571.78125, CrossEntropy: 0.006068980321288109, Accuracy: 0.997786125319693\n",
      "Iter 407 / 2000, Loss: 123634032.3125, CrossEntropy: 0.004486407618969679, Accuracy: 0.9984215153452686\n",
      "Iter 408 / 2000, Loss: 127992333.25, CrossEntropy: 0.0062419516034424305, Accuracy: 0.997790121483376\n",
      "Iter 409 / 2000, Loss: 128600312.25, CrossEntropy: 0.006480933632701635, Accuracy: 0.997790121483376\n",
      "Iter 410 / 2000, Loss: 126300045.34375, CrossEntropy: 0.005544792395085096, Accuracy: 0.9981897378516624\n",
      "Iter 411 / 2000, Loss: 123992479.53125, CrossEntropy: 0.004608322866261005, Accuracy: 0.9985214194373402\n",
      "Iter 412 / 2000, Loss: 128529728.6875, CrossEntropy: 0.006432336755096912, Accuracy: 0.9978300831202046\n",
      "Iter 413 / 2000, Loss: 129262513.9375, CrossEntropy: 0.006701385602355003, Accuracy: 0.9977621483375959\n",
      "Iter 414 / 2000, Loss: 126450835.65625, CrossEntropy: 0.005568411201238632, Accuracy: 0.9981018222506394\n",
      "Iter 415 / 2000, Loss: 128791622.90625, CrossEntropy: 0.0065077836625278, Accuracy: 0.9978620524296675\n",
      "Iter 416 / 2000, Loss: 127018230.84375, CrossEntropy: 0.005792400799691677, Accuracy: 0.9981297953964194\n",
      "Iter 417 / 2000, Loss: 125416145.25, CrossEntropy: 0.005142144858837128, Accuracy: 0.9982217071611253\n",
      "Iter 418 / 2000, Loss: 125446749.40625, CrossEntropy: 0.0051437062211334705, Accuracy: 0.9983415920716112\n",
      "Iter 419 / 2000, Loss: 126440553.96875, CrossEntropy: 0.0055352505296468735, Accuracy: 0.9980618606138107\n",
      "Iter 420 / 2000, Loss: 128723000.65625, CrossEntropy: 0.006440300494432449, Accuracy: 0.9976622442455243\n",
      "Iter 421 / 2000, Loss: 126406379.25, CrossEntropy: 0.005509185139089823, Accuracy: 0.9981018222506394\n",
      "Iter 422 / 2000, Loss: 125672851.53125, CrossEntropy: 0.005212554708123207, Accuracy: 0.998201726342711\n",
      "Iter 423 / 2000, Loss: 125850090.25, CrossEntropy: 0.005305581726133823, Accuracy: 0.9984694693094629\n",
      "Iter 424 / 2000, Loss: 126711266.375, CrossEntropy: 0.00561309140175581, Accuracy: 0.9982217071611253\n",
      "Iter 425 / 2000, Loss: 126276714.46875, CrossEntropy: 0.0054330588318407536, Accuracy: 0.9980618606138107\n",
      "Iter 426 / 2000, Loss: 129201633.1875, CrossEntropy: 0.006594946142286062, Accuracy: 0.9977022058823529\n",
      "Iter 427 / 2000, Loss: 126977295.4375, CrossEntropy: 0.0056997667998075485, Accuracy: 0.9981018222506394\n",
      "Iter 428 / 2000, Loss: 125309304.40625, CrossEntropy: 0.005027463659644127, Accuracy: 0.9982816496163683\n",
      "Iter 429 / 2000, Loss: 128733990.6875, CrossEntropy: 0.006437913980334997, Accuracy: 0.9976382672634272\n",
      "Iter 430 / 2000, Loss: 125655240.875, CrossEntropy: 0.005195724777877331, Accuracy: 0.998229699488491\n",
      "Iter 431 / 2000, Loss: 123413508.59375, CrossEntropy: 0.004273793660104275, Accuracy: 0.9983695652173913\n",
      "Iter 432 / 2000, Loss: 123240935.625, CrossEntropy: 0.004177364055067301, Accuracy: 0.9983615728900256\n",
      "Iter 433 / 2000, Loss: 127208456.71875, CrossEntropy: 0.005757107399404049, Accuracy: 0.9980019181585678\n",
      "Iter 434 / 2000, Loss: 131901378.25, CrossEntropy: 0.007625129073858261, Accuracy: 0.9973425511508951\n",
      "Iter 435 / 2000, Loss: 125923426.21875, CrossEntropy: 0.005230070557445288, Accuracy: 0.9982217071611253\n",
      "Iter 436 / 2000, Loss: 129611240.8125, CrossEntropy: 0.006700677797198296, Accuracy: 0.9976822250639387\n",
      "Iter 437 / 2000, Loss: 129140153.96875, CrossEntropy: 0.0065038506872951984, Accuracy: 0.9980618606138107\n",
      "Iter 438 / 2000, Loss: 127278464.28125, CrossEntropy: 0.005761229898780584, Accuracy: 0.998201726342711\n",
      "Iter 439 / 2000, Loss: 127047897.875, CrossEntropy: 0.005657108034938574, Accuracy: 0.9980418797953964\n",
      "Iter 440 / 2000, Loss: 128125121.40625, CrossEntropy: 0.006079507060348988, Accuracy: 0.9981018222506394\n",
      "Iter 441 / 2000, Loss: 126801350.71875, CrossEntropy: 0.0055445111356675625, Accuracy: 0.9980218989769821\n",
      "Iter 442 / 2000, Loss: 128745690.84375, CrossEntropy: 0.006315300706773996, Accuracy: 0.9980218989769821\n",
      "Iter 443 / 2000, Loss: 125553991.4375, CrossEntropy: 0.005034224595874548, Accuracy: 0.99838155370844\n",
      "Iter 444 / 2000, Loss: 127481836.75, CrossEntropy: 0.005798774771392345, Accuracy: 0.9980418797953964\n",
      "Iter 445 / 2000, Loss: 124430905.6875, CrossEntropy: 0.004574157763272524, Accuracy: 0.9982816496163683\n",
      "Iter 446 / 2000, Loss: 125298005.90625, CrossEntropy: 0.004967579618096352, Accuracy: 0.9983296035805627\n",
      "Iter 447 / 2000, Loss: 124067939.09375, CrossEntropy: 0.00441751629114151, Accuracy: 0.9984215153452686\n",
      "Iter 448 / 2000, Loss: 125768214.09375, CrossEntropy: 0.0050914091989398, Accuracy: 0.9983415920716112\n",
      "Iter 449 / 2000, Loss: 130547191.09375, CrossEntropy: 0.006997543387115002, Accuracy: 0.9976622442455243\n",
      "Iter 450 / 2000, Loss: 128846125.3125, CrossEntropy: 0.006323189940303564, Accuracy: 0.9980099104859336\n",
      "Iter 451 / 2000, Loss: 125813614.875, CrossEntropy: 0.005090523045510054, Accuracy: 0.9982416879795396\n",
      "Iter 452 / 2000, Loss: 126069229.09375, CrossEntropy: 0.005296754650771618, Accuracy: 0.998229699488491\n",
      "Iter 453 / 2000, Loss: 127960023.65625, CrossEntropy: 0.005935882218182087, Accuracy: 0.998141783887468\n",
      "Iter 454 / 2000, Loss: 126989653.625, CrossEntropy: 0.005544747691601515, Accuracy: 0.9980418797953964\n",
      "Iter 455 / 2000, Loss: 128843619.75, CrossEntropy: 0.006282296497374773, Accuracy: 0.9979219948849105\n",
      "Iter 456 / 2000, Loss: 126969389.5625, CrossEntropy: 0.005524771753698587, Accuracy: 0.9981617647058824\n",
      "Iter 457 / 2000, Loss: 125516020.28125, CrossEntropy: 0.004938625730574131, Accuracy: 0.9984015345268542\n",
      "Iter 458 / 2000, Loss: 126696804.46875, CrossEntropy: 0.005427811294794083, Accuracy: 0.9983375959079285\n",
      "Iter 459 / 2000, Loss: 126212193.96875, CrossEntropy: 0.00520332669839263, Accuracy: 0.9984215153452686\n",
      "Iter 460 / 2000, Loss: 126687659.21875, CrossEntropy: 0.00538856303319335, Accuracy: 0.998201726342711\n",
      "Iter 461 / 2000, Loss: 128764040.0, CrossEntropy: 0.0062315622344613075, Accuracy: 0.9979499680306906\n",
      "Iter 462 / 2000, Loss: 127230555.09375, CrossEntropy: 0.0055946651846170425, Accuracy: 0.9981218030690537\n",
      "Iter 463 / 2000, Loss: 126765575.09375, CrossEntropy: 0.0054007843136787415, Accuracy: 0.9985014386189258\n",
      "Iter 464 / 2000, Loss: 124990094.15625, CrossEntropy: 0.004706294741481543, Accuracy: 0.9982696611253197\n",
      "Iter 465 / 2000, Loss: 124025315.0625, CrossEntropy: 0.004297106061130762, Accuracy: 0.9986213235294118\n",
      "Iter 466 / 2000, Loss: 125457158.84375, CrossEntropy: 0.00486451992765069, Accuracy: 0.998321611253197\n",
      "Iter 467 / 2000, Loss: 126551412.0, CrossEntropy: 0.005314547568559647, Accuracy: 0.9983895460358057\n",
      "Iter 468 / 2000, Loss: 124571416.15625, CrossEntropy: 0.004498851951211691, Accuracy: 0.9984814578005116\n",
      "Iter 469 / 2000, Loss: 126856469.625, CrossEntropy: 0.005402932874858379, Accuracy: 0.9980618606138107\n",
      "Iter 470 / 2000, Loss: 128656988.03125, CrossEntropy: 0.006156452931463718, Accuracy: 0.9979499680306906\n",
      "Iter 471 / 2000, Loss: 123471206.46875, CrossEntropy: 0.004039142280817032, Accuracy: 0.9987212276214834\n",
      "Iter 472 / 2000, Loss: 128280534.53125, CrossEntropy: 0.005956306587904692, Accuracy: 0.9980019181585678\n",
      "Iter 473 / 2000, Loss: 124252079.1875, CrossEntropy: 0.004342294298112392, Accuracy: 0.9986812659846548\n",
      "Iter 474 / 2000, Loss: 125395738.3125, CrossEntropy: 0.004791663959622383, Accuracy: 0.9984414961636828\n",
      "Iter 475 / 2000, Loss: 127175047.46875, CrossEntropy: 0.005496473051607609, Accuracy: 0.9980418797953964\n",
      "Iter 476 / 2000, Loss: 126087623.65625, CrossEntropy: 0.005056016147136688, Accuracy: 0.9981617647058824\n",
      "Iter 477 / 2000, Loss: 125511371.03125, CrossEntropy: 0.004819713998585939, Accuracy: 0.998321611253197\n",
      "Iter 478 / 2000, Loss: 128221944.125, CrossEntropy: 0.005896976683288813, Accuracy: 0.9981018222506394\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 479 / 2000, Loss: 126307168.4375, CrossEntropy: 0.005126962438225746, Accuracy: 0.9982816496163683\n",
      "Iter 480 / 2000, Loss: 126073132.875, CrossEntropy: 0.005054300185292959, Accuracy: 0.9981697570332481\n",
      "Iter 481 / 2000, Loss: 127524509.875, CrossEntropy: 0.00560176745057106, Accuracy: 0.9983016304347826\n",
      "Iter 482 / 2000, Loss: 123048457.5625, CrossEntropy: 0.003819049336016178, Accuracy: 0.9984295076726343\n",
      "Iter 483 / 2000, Loss: 126786610.34375, CrossEntropy: 0.0052956449799239635, Accuracy: 0.9981018222506394\n",
      "Iter 484 / 2000, Loss: 127307703.4375, CrossEntropy: 0.005497869104146957, Accuracy: 0.998141783887468\n",
      "Iter 485 / 2000, Loss: 124382579.96875, CrossEntropy: 0.004324048757553101, Accuracy: 0.9986013427109974\n",
      "Iter 486 / 2000, Loss: 127318503.21875, CrossEntropy: 0.005493671167641878, Accuracy: 0.9980618606138107\n",
      "Iter 487 / 2000, Loss: 124602222.9375, CrossEntropy: 0.004400414414703846, Accuracy: 0.9982816496163683\n",
      "Iter 488 / 2000, Loss: 129268136.875, CrossEntropy: 0.006262519396841526, Accuracy: 0.9980818414322251\n",
      "Iter 489 / 2000, Loss: 125357657.46875, CrossEntropy: 0.004697308409959078, Accuracy: 0.9985014386189258\n",
      "Iter 490 / 2000, Loss: 126873520.3125, CrossEntropy: 0.005292088724672794, Accuracy: 0.9981817455242967\n",
      "Iter 491 / 2000, Loss: 124590574.34375, CrossEntropy: 0.004373028874397278, Accuracy: 0.9985613810741688\n",
      "Iter 492 / 2000, Loss: 126189338.25, CrossEntropy: 0.005005632061511278, Accuracy: 0.998261668797954\n",
      "Iter 493 / 2000, Loss: 125618530.625, CrossEntropy: 0.004784366115927696, Accuracy: 0.998349584398977\n",
      "Iter 494 / 2000, Loss: 126330811.84375, CrossEntropy: 0.005050089210271835, Accuracy: 0.9982816496163683\n",
      "Iter 495 / 2000, Loss: 124053687.0625, CrossEntropy: 0.004135235212743282, Accuracy: 0.998641304347826\n",
      "Iter 496 / 2000, Loss: 127480223.25, CrossEntropy: 0.005498634185642004, Accuracy: 0.9981617647058824\n",
      "Iter 497 / 2000, Loss: 128111827.5, CrossEntropy: 0.005745710805058479, Accuracy: 0.9979619565217391\n",
      "Iter 498 / 2000, Loss: 122626752.0625, CrossEntropy: 0.0035485131666064262, Accuracy: 0.9986013427109974\n",
      "Iter 499 / 2000, Loss: 126933363.84375, CrossEntropy: 0.005265115760266781, Accuracy: 0.9985613810741688\n",
      "Iter 500 / 2000, Loss: 125384300.71875, CrossEntropy: 0.004666561726480722, Accuracy: 0.9984295076726343\n",
      "Iter 501 / 2000, Loss: 128383695.90625, CrossEntropy: 0.005835938733071089, Accuracy: 0.9980218989769821\n",
      "Iter 502 / 2000, Loss: 122027522.0, CrossEntropy: 0.0032881535589694977, Accuracy: 0.9989010549872123\n",
      "Iter 503 / 2000, Loss: 128083542.53125, CrossEntropy: 0.005701575428247452, Accuracy: 0.9981617647058824\n",
      "Iter 504 / 2000, Loss: 125091912.8125, CrossEntropy: 0.004500953946262598, Accuracy: 0.9984215153452686\n",
      "Iter 505 / 2000, Loss: 128200058.40625, CrossEntropy: 0.005738408770412207, Accuracy: 0.9979020140664961\n",
      "Iter 506 / 2000, Loss: 124809476.28125, CrossEntropy: 0.004385299980640411, Accuracy: 0.9985014386189258\n",
      "Iter 507 / 2000, Loss: 129979405.78125, CrossEntropy: 0.006437588017433882, Accuracy: 0.9979819373401535\n",
      "Iter 508 / 2000, Loss: 123263815.0, CrossEntropy: 0.003748709335923195, Accuracy: 0.9987811700767263\n",
      "Iter 509 / 2000, Loss: 130023377.875, CrossEntropy: 0.006443183869123459, Accuracy: 0.998141783887468\n",
      "Iter 510 / 2000, Loss: 126543230.625, CrossEntropy: 0.005046405829489231, Accuracy: 0.9985014386189258\n",
      "Iter 511 / 2000, Loss: 129787779.53125, CrossEntropy: 0.006338386330753565, Accuracy: 0.9979819373401535\n",
      "Iter 512 / 2000, Loss: 126310892.375, CrossEntropy: 0.004945010878145695, Accuracy: 0.9983615728900256\n",
      "Iter 513 / 2000, Loss: 126467866.0, CrossEntropy: 0.005009095184504986, Accuracy: 0.9984694693094629\n",
      "Iter 514 / 2000, Loss: 127500723.65625, CrossEntropy: 0.005458479281514883, Accuracy: 0.9981897378516624\n",
      "Iter 515 / 2000, Loss: 124335657.875, CrossEntropy: 0.004136516246944666, Accuracy: 0.998701246803069\n",
      "Iter 516 / 2000, Loss: 126427909.625, CrossEntropy: 0.004967804998159409, Accuracy: 0.9984015345268542\n",
      "Iter 517 / 2000, Loss: 127233399.15625, CrossEntropy: 0.005283795762807131, Accuracy: 0.998201726342711\n",
      "Iter 518 / 2000, Loss: 126561046.25, CrossEntropy: 0.005011492874473333, Accuracy: 0.9984614769820972\n",
      "Iter 519 / 2000, Loss: 127128829.15625, CrossEntropy: 0.005232539027929306, Accuracy: 0.9982217071611253\n",
      "Iter 520 / 2000, Loss: 129570674.6875, CrossEntropy: 0.0062017543241381645, Accuracy: 0.9981018222506394\n",
      "Iter 521 / 2000, Loss: 124266830.6875, CrossEntropy: 0.00407812325283885, Accuracy: 0.9985414002557544\n",
      "Iter 522 / 2000, Loss: 125539065.0625, CrossEntropy: 0.004579988773912191, Accuracy: 0.9986812659846548\n",
      "Iter 523 / 2000, Loss: 124686003.03125, CrossEntropy: 0.0042342194356024265, Accuracy: 0.9985613810741688\n",
      "Iter 524 / 2000, Loss: 127788662.53125, CrossEntropy: 0.0054856594651937485, Accuracy: 0.9980099104859336\n",
      "Iter 525 / 2000, Loss: 125628579.15625, CrossEntropy: 0.004598652943968773, Accuracy: 0.9984814578005116\n",
      "Iter 526 / 2000, Loss: 124975508.84375, CrossEntropy: 0.004338380880653858, Accuracy: 0.9986812659846548\n",
      "Iter 527 / 2000, Loss: 129440961.59375, CrossEntropy: 0.00611129542812705, Accuracy: 0.9980218989769821\n",
      "Iter 528 / 2000, Loss: 126383659.65625, CrossEntropy: 0.004890224896371365, Accuracy: 0.9983415920716112\n",
      "Iter 529 / 2000, Loss: 123848035.09375, CrossEntropy: 0.0038650506176054478, Accuracy: 0.9986612851662404\n",
      "Iter 530 / 2000, Loss: 123818112.03125, CrossEntropy: 0.00384855386801064, Accuracy: 0.9986013427109974\n",
      "Iter 531 / 2000, Loss: 127513727.75, CrossEntropy: 0.005319246556609869, Accuracy: 0.998261668797954\n",
      "Iter 532 / 2000, Loss: 128007430.40625, CrossEntropy: 0.005512123927474022, Accuracy: 0.9979419757033248\n",
      "Iter 533 / 2000, Loss: 124930272.125, CrossEntropy: 0.004275300074368715, Accuracy: 0.9984015345268542\n",
      "Iter 534 / 2000, Loss: 127081844.5625, CrossEntropy: 0.005129612982273102, Accuracy: 0.9981218030690537\n",
      "Iter 535 / 2000, Loss: 125042844.28125, CrossEntropy: 0.004308921284973621, Accuracy: 0.9986013427109974\n",
      "Iter 536 / 2000, Loss: 126585231.46875, CrossEntropy: 0.004920481238514185, Accuracy: 0.9983615728900256\n",
      "Iter 537 / 2000, Loss: 125387144.21875, CrossEntropy: 0.00443722540512681, Accuracy: 0.9987412084398977\n",
      "Iter 538 / 2000, Loss: 126226302.34375, CrossEntropy: 0.004766364581882954, Accuracy: 0.9984614769820972\n",
      "Iter 539 / 2000, Loss: 126881277.78125, CrossEntropy: 0.005022738594561815, Accuracy: 0.9982416879795396\n",
      "Iter 540 / 2000, Loss: 126548199.0, CrossEntropy: 0.004884207621216774, Accuracy: 0.9984614769820972\n",
      "Iter 541 / 2000, Loss: 126194345.5, CrossEntropy: 0.004738352261483669, Accuracy: 0.998701246803069\n",
      "Iter 542 / 2000, Loss: 125410089.875, CrossEntropy: 0.004419348668307066, Accuracy: 0.998641304347826\n",
      "Iter 543 / 2000, Loss: 128110304.8125, CrossEntropy: 0.005492789670825005, Accuracy: 0.998261668797954\n",
      "Iter 544 / 2000, Loss: 126451382.5625, CrossEntropy: 0.00482397573068738, Accuracy: 0.9985613810741688\n",
      "Iter 545 / 2000, Loss: 126524616.09375, CrossEntropy: 0.004849538207054138, Accuracy: 0.998321611253197\n",
      "Iter 546 / 2000, Loss: 127587101.34375, CrossEntropy: 0.00526784872636199, Accuracy: 0.9981817455242967\n",
      "Iter 547 / 2000, Loss: 126686169.8125, CrossEntropy: 0.0049019125290215015, Accuracy: 0.9985214194373402\n",
      "Iter 548 / 2000, Loss: 126271234.5625, CrossEntropy: 0.00473072798922658, Accuracy: 0.9984215153452686\n",
      "Iter 549 / 2000, Loss: 123176783.75, CrossEntropy: 0.0034905129577964544, Accuracy: 0.9987811700767263\n",
      "Iter 550 / 2000, Loss: 126748458.1875, CrossEntropy: 0.004910553805530071, Accuracy: 0.9983615728900256\n",
      "Iter 551 / 2000, Loss: 125842769.65625, CrossEntropy: 0.004559810739010572, Accuracy: 0.9984694693094629\n",
      "Iter 552 / 2000, Loss: 125617765.78125, CrossEntropy: 0.004448324907571077, Accuracy: 0.9986213235294118\n",
      "Iter 553 / 2000, Loss: 124436269.6875, CrossEntropy: 0.003971691709011793, Accuracy: 0.998641304347826\n",
      "Iter 554 / 2000, Loss: 127342352.75, CrossEntropy: 0.005176811013370752, Accuracy: 0.998289641943734\n",
      "Iter 555 / 2000, Loss: 124815091.09375, CrossEntropy: 0.0041370210237801075, Accuracy: 0.9987691815856777\n",
      "Iter 556 / 2000, Loss: 127125679.46875, CrossEntropy: 0.005036534741520882, Accuracy: 0.9982416879795396\n",
      "Iter 557 / 2000, Loss: 123454201.28125, CrossEntropy: 0.0035573849454522133, Accuracy: 0.9986213235294118\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 558 / 2000, Loss: 124478072.40625, CrossEntropy: 0.003961860202252865, Accuracy: 0.9986812659846548\n",
      "Iter 559 / 2000, Loss: 131341758.75, CrossEntropy: 0.006699005141854286, Accuracy: 0.9981018222506394\n",
      "Iter 560 / 2000, Loss: 126890192.75, CrossEntropy: 0.004914476536214352, Accuracy: 0.9984015345268542\n",
      "Iter 561 / 2000, Loss: 127364188.53125, CrossEntropy: 0.005099470727145672, Accuracy: 0.9983615728900256\n",
      "Iter 562 / 2000, Loss: 127448984.1875, CrossEntropy: 0.00512884883210063, Accuracy: 0.9984814578005116\n",
      "Iter 563 / 2000, Loss: 127821093.9375, CrossEntropy: 0.005292982794344425, Accuracy: 0.9983096227621484\n",
      "Iter 564 / 2000, Loss: 126305248.625, CrossEntropy: 0.004689032211899757, Accuracy: 0.9983096227621484\n",
      "Iter 565 / 2000, Loss: 127276631.3125, CrossEntropy: 0.005043043754994869, Accuracy: 0.998261668797954\n",
      "Iter 566 / 2000, Loss: 125325468.34375, CrossEntropy: 0.004322207998484373, Accuracy: 0.998617327365729\n",
      "Iter 567 / 2000, Loss: 127482103.96875, CrossEntropy: 0.005114630796015263, Accuracy: 0.9985214194373402\n",
      "Iter 568 / 2000, Loss: 125949836.3125, CrossEntropy: 0.00449947826564312, Accuracy: 0.998641304347826\n",
      "Iter 569 / 2000, Loss: 124963694.8125, CrossEntropy: 0.004097841214388609, Accuracy: 0.9986213235294118\n",
      "Iter 570 / 2000, Loss: 130418116.96875, CrossEntropy: 0.006316803395748138, Accuracy: 0.9981297953964194\n",
      "Iter 571 / 2000, Loss: 124376676.40625, CrossEntropy: 0.0038529164157807827, Accuracy: 0.9986013427109974\n",
      "Iter 572 / 2000, Loss: 125494208.15625, CrossEntropy: 0.004294674843549728, Accuracy: 0.9986812659846548\n",
      "Iter 573 / 2000, Loss: 129673647.25, CrossEntropy: 0.00599024910479784, Accuracy: 0.9980698529411764\n",
      "Iter 574 / 2000, Loss: 124367434.875, CrossEntropy: 0.003841965924948454, Accuracy: 0.9989410166240409\n",
      "Iter 575 / 2000, Loss: 124263898.6875, CrossEntropy: 0.0038629991468042135, Accuracy: 0.9986373081841433\n",
      "Iter 576 / 2000, Loss: 128920924.75, CrossEntropy: 0.005642613861709833, Accuracy: 0.9982816496163683\n",
      "Iter 577 / 2000, Loss: 125704602.5625, CrossEntropy: 0.004353011958301067, Accuracy: 0.998821131713555\n",
      "Iter 578 / 2000, Loss: 124718378.96875, CrossEntropy: 0.003953456878662109, Accuracy: 0.9986213235294118\n",
      "Iter 579 / 2000, Loss: 126024161.75, CrossEntropy: 0.004496055189520121, Accuracy: 0.9984694693094629\n",
      "Iter 580 / 2000, Loss: 125163422.25, CrossEntropy: 0.004121182486414909, Accuracy: 0.9986013427109974\n",
      "Iter 581 / 2000, Loss: 127402834.03125, CrossEntropy: 0.0050776260904967785, Accuracy: 0.9983695652173913\n",
      "Iter 582 / 2000, Loss: 126495220.4375, CrossEntropy: 0.004644540138542652, Accuracy: 0.9984614769820972\n",
      "Iter 583 / 2000, Loss: 125199626.8125, CrossEntropy: 0.004121971316635609, Accuracy: 0.99838155370844\n",
      "Iter 584 / 2000, Loss: 127342068.71875, CrossEntropy: 0.004971739370375872, Accuracy: 0.9984015345268542\n",
      "Iter 585 / 2000, Loss: 124779967.375, CrossEntropy: 0.003943432588130236, Accuracy: 0.998641304347826\n",
      "Iter 586 / 2000, Loss: 126907715.09375, CrossEntropy: 0.004797478672116995, Accuracy: 0.99840952685422\n",
      "Iter 587 / 2000, Loss: 124684159.09375, CrossEntropy: 0.0038951041642576456, Accuracy: 0.9985414002557544\n",
      "Iter 588 / 2000, Loss: 124490642.46875, CrossEntropy: 0.003840389661490917, Accuracy: 0.9987492007672635\n",
      "Iter 589 / 2000, Loss: 126814319.40625, CrossEntropy: 0.004737361799925566, Accuracy: 0.9985414002557544\n",
      "Iter 590 / 2000, Loss: 128675681.21875, CrossEntropy: 0.005485003348439932, Accuracy: 0.9980818414322251\n",
      "Iter 591 / 2000, Loss: 126305033.53125, CrossEntropy: 0.004524285439401865, Accuracy: 0.9985214194373402\n",
      "Iter 592 / 2000, Loss: 126477118.375, CrossEntropy: 0.004586267285048962, Accuracy: 0.9984414961636828\n",
      "Iter 593 / 2000, Loss: 126938312.46875, CrossEntropy: 0.004765992518514395, Accuracy: 0.9985613810741688\n",
      "Iter 594 / 2000, Loss: 127232532.375, CrossEntropy: 0.004877092316746712, Accuracy: 0.9984614769820972\n",
      "Iter 595 / 2000, Loss: 125280184.09375, CrossEntropy: 0.004091794602572918, Accuracy: 0.998701246803069\n",
      "Iter 596 / 2000, Loss: 127005584.6875, CrossEntropy: 0.004779748152941465, Accuracy: 0.9984414961636828\n",
      "Iter 597 / 2000, Loss: 125679212.625, CrossEntropy: 0.004242207854986191, Accuracy: 0.9985214194373402\n",
      "Iter 598 / 2000, Loss: 127547293.9375, CrossEntropy: 0.004994410090148449, Accuracy: 0.9983895460358057\n",
      "Iter 599 / 2000, Loss: 124070694.59375, CrossEntropy: 0.0036185041535645723, Accuracy: 0.9988091432225065\n",
      "Iter 600 / 2000, Loss: 126459984.0625, CrossEntropy: 0.004544353112578392, Accuracy: 0.9985414002557544\n",
      "Iter 601 / 2000, Loss: 127796204.03125, CrossEntropy: 0.00506643345579505, Accuracy: 0.9984015345268542\n",
      "Iter 602 / 2000, Loss: 125006465.3125, CrossEntropy: 0.00394681328907609, Accuracy: 0.9987212276214834\n",
      "Iter 603 / 2000, Loss: 126368655.53125, CrossEntropy: 0.004486253950744867, Accuracy: 0.9985414002557544\n",
      "Iter 604 / 2000, Loss: 126223905.84375, CrossEntropy: 0.004464804194867611, Accuracy: 0.9985893542199489\n",
      "Iter 605 / 2000, Loss: 124684139.90625, CrossEntropy: 0.003803204046562314, Accuracy: 0.9987212276214834\n",
      "Iter 606 / 2000, Loss: 125770898.09375, CrossEntropy: 0.004233142826706171, Accuracy: 0.9987811700767263\n",
      "Iter 607 / 2000, Loss: 126733810.0, CrossEntropy: 0.00461215665563941, Accuracy: 0.9985214194373402\n",
      "Iter 608 / 2000, Loss: 124619186.40625, CrossEntropy: 0.00376279279589653, Accuracy: 0.9986213235294118\n",
      "Iter 609 / 2000, Loss: 124112343.21875, CrossEntropy: 0.0035562566481530666, Accuracy: 0.9987412084398977\n",
      "Iter 610 / 2000, Loss: 126901826.34375, CrossEntropy: 0.004666390363126993, Accuracy: 0.9984015345268542\n",
      "Iter 611 / 2000, Loss: 125668017.15625, CrossEntropy: 0.004167474806308746, Accuracy: 0.998701246803069\n",
      "Iter 612 / 2000, Loss: 129025058.375, CrossEntropy: 0.005504612345248461, Accuracy: 0.998261668797954\n",
      "Iter 613 / 2000, Loss: 124053588.84375, CrossEntropy: 0.0035132949706166983, Accuracy: 0.9988610933503836\n",
      "Iter 614 / 2000, Loss: 124474558.71875, CrossEntropy: 0.0036763797979801893, Accuracy: 0.9990009590792839\n",
      "Iter 615 / 2000, Loss: 130218251.65625, CrossEntropy: 0.0059683481231331825, Accuracy: 0.998141783887468\n",
      "Iter 616 / 2000, Loss: 125428669.09375, CrossEntropy: 0.00404773373156786, Accuracy: 0.9986812659846548\n",
      "Iter 617 / 2000, Loss: 125587439.875, CrossEntropy: 0.0041063702665269375, Accuracy: 0.9985813618925832\n",
      "Iter 618 / 2000, Loss: 125290193.5625, CrossEntropy: 0.003982879221439362, Accuracy: 0.9988610933503836\n",
      "Iter 619 / 2000, Loss: 126315188.875, CrossEntropy: 0.00438902759924531, Accuracy: 0.9984614769820972\n",
      "Iter 620 / 2000, Loss: 125347715.21875, CrossEntropy: 0.004014418926090002, Accuracy: 0.9986093350383632\n",
      "Iter 621 / 2000, Loss: 126011940.875, CrossEntropy: 0.004255978856235743, Accuracy: 0.9986013427109974\n",
      "Iter 622 / 2000, Loss: 126406361.21875, CrossEntropy: 0.004415118135511875, Accuracy: 0.9985613810741688\n",
      "Iter 623 / 2000, Loss: 125079584.78125, CrossEntropy: 0.003921779338270426, Accuracy: 0.9988091432225065\n",
      "Iter 624 / 2000, Loss: 125720969.78125, CrossEntropy: 0.004126184619963169, Accuracy: 0.998641304347826\n",
      "Iter 625 / 2000, Loss: 127025070.875, CrossEntropy: 0.004643185064196587, Accuracy: 0.9984215153452686\n",
      "Iter 626 / 2000, Loss: 126997373.96875, CrossEntropy: 0.0046269516460597515, Accuracy: 0.9984614769820972\n",
      "Iter 627 / 2000, Loss: 126713652.1875, CrossEntropy: 0.00451196264475584, Accuracy: 0.9981817455242967\n",
      "Iter 628 / 2000, Loss: 126525468.6875, CrossEntropy: 0.004429022781550884, Accuracy: 0.9987412084398977\n",
      "Iter 629 / 2000, Loss: 126215017.71875, CrossEntropy: 0.0042997440323233604, Accuracy: 0.9985214194373402\n",
      "Iter 630 / 2000, Loss: 126862083.0, CrossEntropy: 0.004553818143904209, Accuracy: 0.9985613810741688\n",
      "Iter 631 / 2000, Loss: 126022218.46875, CrossEntropy: 0.004214976914227009, Accuracy: 0.99838155370844\n",
      "Iter 632 / 2000, Loss: 127357138.4375, CrossEntropy: 0.004741596523672342, Accuracy: 0.9984015345268542\n",
      "Iter 633 / 2000, Loss: 125884938.125, CrossEntropy: 0.004229424521327019, Accuracy: 0.998617327365729\n",
      "Iter 634 / 2000, Loss: 126020409.96875, CrossEntropy: 0.004196367226541042, Accuracy: 0.998761189258312\n",
      "Iter 635 / 2000, Loss: 127246478.4375, CrossEntropy: 0.0046821641735732555, Accuracy: 0.9984015345268542\n",
      "Iter 636 / 2000, Loss: 124695239.6875, CrossEntropy: 0.0036574651021510363, Accuracy: 0.9986213235294118\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 637 / 2000, Loss: 125256249.4375, CrossEntropy: 0.0038826174568384886, Accuracy: 0.9988011508951407\n",
      "Iter 638 / 2000, Loss: 126519320.3125, CrossEntropy: 0.004376781638711691, Accuracy: 0.9986213235294118\n",
      "Iter 639 / 2000, Loss: 129263471.0, CrossEntropy: 0.005467970389872789, Accuracy: 0.9983615728900256\n",
      "Iter 640 / 2000, Loss: 126938967.40625, CrossEntropy: 0.004534845240414143, Accuracy: 0.998701246803069\n",
      "Iter 641 / 2000, Loss: 125539605.3125, CrossEntropy: 0.003971392288804054, Accuracy: 0.9987212276214834\n",
      "Iter 642 / 2000, Loss: 129362764.46875, CrossEntropy: 0.0055023180320858955, Accuracy: 0.998349584398977\n",
      "Iter 643 / 2000, Loss: 125807372.65625, CrossEntropy: 0.004068497568368912, Accuracy: 0.998761189258312\n",
      "Iter 644 / 2000, Loss: 127444076.125, CrossEntropy: 0.004716977011412382, Accuracy: 0.9985414002557544\n",
      "Iter 645 / 2000, Loss: 126276055.84375, CrossEntropy: 0.004259483888745308, Accuracy: 0.9986492966751919\n",
      "Iter 646 / 2000, Loss: 128022453.34375, CrossEntropy: 0.004938721191138029, Accuracy: 0.9985414002557544\n",
      "Iter 647 / 2000, Loss: 129304729.75, CrossEntropy: 0.00544627895578742, Accuracy: 0.9983415920716112\n",
      "Iter 648 / 2000, Loss: 125002553.375, CrossEntropy: 0.0037220853846520185, Accuracy: 0.9987412084398977\n",
      "Iter 649 / 2000, Loss: 123833282.40625, CrossEntropy: 0.003256449243053794, Accuracy: 0.9988610933503836\n",
      "Iter 650 / 2000, Loss: 125207526.0, CrossEntropy: 0.003796095959842205, Accuracy: 0.9986213235294118\n",
      "Iter 651 / 2000, Loss: 125860547.46875, CrossEntropy: 0.004050564952194691, Accuracy: 0.9985014386189258\n",
      "Iter 652 / 2000, Loss: 125831641.625, CrossEntropy: 0.0040352558717131615, Accuracy: 0.9987412084398977\n",
      "Iter 653 / 2000, Loss: 126715700.78125, CrossEntropy: 0.0043833437375724316, Accuracy: 0.9986612851662404\n",
      "Iter 654 / 2000, Loss: 124903391.25, CrossEntropy: 0.0036552082747220993, Accuracy: 0.9988411125319693\n",
      "Iter 655 / 2000, Loss: 126558956.25, CrossEntropy: 0.004391956143081188, Accuracy: 0.9985094309462916\n",
      "Iter 656 / 2000, Loss: 124920173.375, CrossEntropy: 0.0036519095301628113, Accuracy: 0.9988810741687979\n",
      "Iter 657 / 2000, Loss: 125417159.09375, CrossEntropy: 0.0038461301010102034, Accuracy: 0.9987212276214834\n",
      "Iter 658 / 2000, Loss: 124588053.8125, CrossEntropy: 0.0035127245355397463, Accuracy: 0.9987412084398977\n",
      "Iter 659 / 2000, Loss: 127002221.53125, CrossEntropy: 0.004469929728657007, Accuracy: 0.9983415920716112\n",
      "Iter 660 / 2000, Loss: 123974582.75, CrossEntropy: 0.0032561318948864937, Accuracy: 0.9986612851662404\n",
      "Iter 661 / 2000, Loss: 125072854.09375, CrossEntropy: 0.0036911331117153168, Accuracy: 0.998761189258312\n",
      "Iter 662 / 2000, Loss: 126501361.40625, CrossEntropy: 0.004258288536220789, Accuracy: 0.9988411125319693\n",
      "Iter 663 / 2000, Loss: 127489789.28125, CrossEntropy: 0.004647227469831705, Accuracy: 0.9983615728900256\n",
      "Iter 664 / 2000, Loss: 126897877.90625, CrossEntropy: 0.004407181870192289, Accuracy: 0.9987212276214834\n",
      "Iter 665 / 2000, Loss: 128058665.78125, CrossEntropy: 0.004865548107773066, Accuracy: 0.9984414961636828\n",
      "Iter 666 / 2000, Loss: 126928497.3125, CrossEntropy: 0.004408661276102066, Accuracy: 0.9985214194373402\n",
      "Iter 667 / 2000, Loss: 125301773.125, CrossEntropy: 0.0037553037982434034, Accuracy: 0.998641304347826\n",
      "Iter 668 / 2000, Loss: 124362406.3125, CrossEntropy: 0.003375112544745207, Accuracy: 0.9988610933503836\n",
      "Iter 669 / 2000, Loss: 128693243.6875, CrossEntropy: 0.005099436268210411, Accuracy: 0.9983016304347826\n",
      "Iter 670 / 2000, Loss: 125972687.0625, CrossEntropy: 0.004009679891169071, Accuracy: 0.9985613810741688\n",
      "Iter 671 / 2000, Loss: 126066693.46875, CrossEntropy: 0.004147381987422705, Accuracy: 0.9986492966751919\n",
      "Iter 672 / 2000, Loss: 126849771.0, CrossEntropy: 0.004349302034825087, Accuracy: 0.9985414002557544\n",
      "Iter 673 / 2000, Loss: 126196454.53125, CrossEntropy: 0.004083344247192144, Accuracy: 0.9986213235294118\n",
      "Iter 674 / 2000, Loss: 126675536.53125, CrossEntropy: 0.004330122377723455, Accuracy: 0.9984494884910486\n",
      "Iter 675 / 2000, Loss: 130064450.09375, CrossEntropy: 0.005642987322062254, Accuracy: 0.998349584398977\n",
      "Iter 676 / 2000, Loss: 124890644.4375, CrossEntropy: 0.003631831146776676, Accuracy: 0.9987372122762149\n",
      "Iter 677 / 2000, Loss: 127228961.40625, CrossEntropy: 0.0044769844971597195, Accuracy: 0.9986213235294118\n",
      "Iter 678 / 2000, Loss: 127345009.5625, CrossEntropy: 0.004597228951752186, Accuracy: 0.9985374040920717\n",
      "Iter 679 / 2000, Loss: 124903027.125, CrossEntropy: 0.003538735443726182, Accuracy: 0.9987811700767263\n",
      "Iter 680 / 2000, Loss: 125697811.5, CrossEntropy: 0.003857898758724332, Accuracy: 0.9987412084398977\n",
      "Iter 681 / 2000, Loss: 128222556.28125, CrossEntropy: 0.0048597087152302265, Accuracy: 0.9984414961636828\n",
      "Iter 682 / 2000, Loss: 125363114.59375, CrossEntropy: 0.0037219442892819643, Accuracy: 0.9987292199488491\n",
      "Iter 683 / 2000, Loss: 125796379.9375, CrossEntropy: 0.00387581717222929, Accuracy: 0.998701246803069\n",
      "Iter 684 / 2000, Loss: 124235192.34375, CrossEntropy: 0.0032477464992552996, Accuracy: 0.9990009590792839\n",
      "Iter 685 / 2000, Loss: 127859605.15625, CrossEntropy: 0.004713403061032295, Accuracy: 0.9985693734015345\n",
      "Iter 686 / 2000, Loss: 126167956.5, CrossEntropy: 0.00401089433580637, Accuracy: 0.9985813618925832\n",
      "Iter 687 / 2000, Loss: 125879044.71875, CrossEntropy: 0.003890544641762972, Accuracy: 0.9987212276214834\n",
      "Iter 688 / 2000, Loss: 126547107.5, CrossEntropy: 0.0041540698148310184, Accuracy: 0.9987811700767263\n",
      "Iter 689 / 2000, Loss: 126671258.46875, CrossEntropy: 0.004198197741061449, Accuracy: 0.9986612851662404\n",
      "Iter 690 / 2000, Loss: 125979157.5625, CrossEntropy: 0.003921471070498228, Accuracy: 0.9987212276214834\n",
      "Iter 691 / 2000, Loss: 126096145.15625, CrossEntropy: 0.003959137015044689, Accuracy: 0.9987212276214834\n",
      "Iter 692 / 2000, Loss: 125579097.125, CrossEntropy: 0.003747955895960331, Accuracy: 0.998821131713555\n",
      "Iter 693 / 2000, Loss: 127329986.875, CrossEntropy: 0.0044447616674005985, Accuracy: 0.9984614769820972\n",
      "Iter 694 / 2000, Loss: 128747290.71875, CrossEntropy: 0.005004246719181538, Accuracy: 0.9984814578005116\n",
      "Iter 695 / 2000, Loss: 129388189.15625, CrossEntropy: 0.005254887975752354, Accuracy: 0.9984015345268542\n",
      "Iter 696 / 2000, Loss: 125137716.28125, CrossEntropy: 0.0036123183090239763, Accuracy: 0.9989090473145781\n",
      "Iter 697 / 2000, Loss: 126190651.28125, CrossEntropy: 0.004002497531473637, Accuracy: 0.9985693734015345\n",
      "Iter 698 / 2000, Loss: 125359181.1875, CrossEntropy: 0.0036318846978247166, Accuracy: 0.998761189258312\n",
      "Iter 699 / 2000, Loss: 127379957.125, CrossEntropy: 0.0044344221241772175, Accuracy: 0.9985813618925832\n",
      "Iter 700 / 2000, Loss: 127655654.5625, CrossEntropy: 0.004559764638543129, Accuracy: 0.9986892583120205\n",
      "Iter 701 / 2000, Loss: 124909092.0, CrossEntropy: 0.00344644021242857, Accuracy: 0.9987691815856777\n",
      "Iter 702 / 2000, Loss: 127365406.3125, CrossEntropy: 0.004415537230670452, Accuracy: 0.9987811700767263\n",
      "Iter 703 / 2000, Loss: 127452161.8125, CrossEntropy: 0.004445320926606655, Accuracy: 0.9985813618925832\n",
      "Iter 704 / 2000, Loss: 124995169.0, CrossEntropy: 0.0034684999845921993, Accuracy: 0.9987492007672635\n",
      "Iter 705 / 2000, Loss: 126819903.5625, CrossEntropy: 0.004183440003544092, Accuracy: 0.9987412084398977\n",
      "Iter 706 / 2000, Loss: 127103249.6875, CrossEntropy: 0.004292544908821583, Accuracy: 0.9986213235294118\n",
      "Iter 707 / 2000, Loss: 127352766.78125, CrossEntropy: 0.004387154709547758, Accuracy: 0.9984814578005116\n",
      "Iter 708 / 2000, Loss: 127565372.90625, CrossEntropy: 0.004466872662305832, Accuracy: 0.9984414961636828\n",
      "Iter 709 / 2000, Loss: 125192903.84375, CrossEntropy: 0.003514290787279606, Accuracy: 0.9988411125319693\n",
      "Iter 710 / 2000, Loss: 129673618.25, CrossEntropy: 0.005304121412336826, Accuracy: 0.9984015345268542\n",
      "Iter 711 / 2000, Loss: 126247076.15625, CrossEntropy: 0.003931899555027485, Accuracy: 0.998821131713555\n",
      "Iter 712 / 2000, Loss: 125748067.34375, CrossEntropy: 0.003722907043993473, Accuracy: 0.9986612851662404\n",
      "Iter 713 / 2000, Loss: 126000910.6875, CrossEntropy: 0.00381892709992826, Accuracy: 0.9988011508951407\n",
      "Iter 714 / 2000, Loss: 125788504.71875, CrossEntropy: 0.0037300027906894684, Accuracy: 0.998761189258312\n",
      "Iter 715 / 2000, Loss: 124293270.75, CrossEntropy: 0.0031685330905020237, Accuracy: 0.9988570971867008\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 716 / 2000, Loss: 129354973.40625, CrossEntropy: 0.005146046634763479, Accuracy: 0.9984015345268542\n",
      "Iter 717 / 2000, Loss: 126013854.3125, CrossEntropy: 0.003812221810221672, Accuracy: 0.9987811700767263\n",
      "Iter 718 / 2000, Loss: 126463488.40625, CrossEntropy: 0.004027649760246277, Accuracy: 0.9987092391304349\n",
      "Iter 719 / 2000, Loss: 123575772.59375, CrossEntropy: 0.0028234284836798906, Accuracy: 0.9989809782608695\n",
      "Iter 720 / 2000, Loss: 126442343.03125, CrossEntropy: 0.003964555915445089, Accuracy: 0.9987412084398977\n",
      "Iter 721 / 2000, Loss: 126725190.3125, CrossEntropy: 0.004072797019034624, Accuracy: 0.998641304347826\n",
      "Iter 722 / 2000, Loss: 127868164.53125, CrossEntropy: 0.004524844698607922, Accuracy: 0.9984414961636828\n",
      "Iter 723 / 2000, Loss: 125540041.96875, CrossEntropy: 0.0035905404947698116, Accuracy: 0.9988411125319693\n",
      "Iter 724 / 2000, Loss: 126248714.5625, CrossEntropy: 0.003869046922773123, Accuracy: 0.998821131713555\n",
      "Iter 725 / 2000, Loss: 127067058.1875, CrossEntropy: 0.004196823574602604, Accuracy: 0.9985014386189258\n",
      "Iter 726 / 2000, Loss: 126024363.09375, CrossEntropy: 0.0037904721684753895, Accuracy: 0.9986892583120205\n",
      "Iter 727 / 2000, Loss: 126287038.96875, CrossEntropy: 0.003874434158205986, Accuracy: 0.9987811700767263\n",
      "Iter 728 / 2000, Loss: 127780562.875, CrossEntropy: 0.004464594181627035, Accuracy: 0.9985613810741688\n",
      "Iter 729 / 2000, Loss: 127085471.15625, CrossEntropy: 0.004182614851742983, Accuracy: 0.9987212276214834\n",
      "Iter 730 / 2000, Loss: 124630493.65625, CrossEntropy: 0.0031970604322850704, Accuracy: 0.9988610933503836\n",
      "Iter 731 / 2000, Loss: 126343211.53125, CrossEntropy: 0.003877441631630063, Accuracy: 0.9985214194373402\n",
      "Iter 732 / 2000, Loss: 125263808.375, CrossEntropy: 0.00344269210472703, Accuracy: 0.9987412084398977\n",
      "Iter 733 / 2000, Loss: 127185334.4375, CrossEntropy: 0.00424582976847887, Accuracy: 0.9985294117647059\n",
      "Iter 734 / 2000, Loss: 123540669.53125, CrossEntropy: 0.002747816266492009, Accuracy: 0.9991008631713555\n",
      "Iter 735 / 2000, Loss: 126105684.84375, CrossEntropy: 0.0037655376363545656, Accuracy: 0.998701246803069\n",
      "Iter 736 / 2000, Loss: 126202571.65625, CrossEntropy: 0.00380675564520061, Accuracy: 0.9987412084398977\n",
      "Iter 737 / 2000, Loss: 126668046.28125, CrossEntropy: 0.003982032183557749, Accuracy: 0.9986612851662404\n",
      "Iter 738 / 2000, Loss: 121833600.90625, CrossEntropy: 0.0020459701772779226, Accuracy: 0.9994005754475703\n",
      "Iter 739 / 2000, Loss: 130596296.1875, CrossEntropy: 0.005543248262256384, Accuracy: 0.9982416879795396\n",
      "Iter 740 / 2000, Loss: 122895063.90625, CrossEntropy: 0.0024631160777062178, Accuracy: 0.9992007672634271\n",
      "Iter 741 / 2000, Loss: 126421916.5625, CrossEntropy: 0.003906558733433485, Accuracy: 0.9987891624040921\n",
      "Iter 742 / 2000, Loss: 124628384.40625, CrossEntropy: 0.0031483087223023176, Accuracy: 0.998821131713555\n",
      "Iter 743 / 2000, Loss: 127807154.75, CrossEntropy: 0.004413787741214037, Accuracy: 0.9987412084398977\n",
      "Iter 744 / 2000, Loss: 126987035.78125, CrossEntropy: 0.004082222934812307, Accuracy: 0.9986812659846548\n",
      "Iter 745 / 2000, Loss: 125253870.40625, CrossEntropy: 0.0033851447515189648, Accuracy: 0.9988610933503836\n",
      "Iter 746 / 2000, Loss: 127970244.65625, CrossEntropy: 0.004479005467146635, Accuracy: 0.9986492966751919\n",
      "Iter 747 / 2000, Loss: 127333675.3125, CrossEntropy: 0.004236765671521425, Accuracy: 0.9985294117647059\n",
      "Iter 748 / 2000, Loss: 128473367.90625, CrossEntropy: 0.004657509736716747, Accuracy: 0.9985414002557544\n",
      "Iter 749 / 2000, Loss: 123756301.78125, CrossEntropy: 0.002769531449303031, Accuracy: 0.9989809782608695\n",
      "Iter 750 / 2000, Loss: 128055263.875, CrossEntropy: 0.00448265578597784, Accuracy: 0.9984814578005116\n",
      "Iter 751 / 2000, Loss: 124762178.25, CrossEntropy: 0.003179699182510376, Accuracy: 0.9988411125319693\n",
      "Iter 752 / 2000, Loss: 127922388.15625, CrossEntropy: 0.004422362893819809, Accuracy: 0.9986812659846548\n",
      "Iter 753 / 2000, Loss: 126475258.9375, CrossEntropy: 0.003840724704787135, Accuracy: 0.9987811700767263\n",
      "Iter 754 / 2000, Loss: 125659975.59375, CrossEntropy: 0.003509286092594266, Accuracy: 0.9988810741687979\n",
      "Iter 755 / 2000, Loss: 125672339.3125, CrossEntropy: 0.003509902162477374, Accuracy: 0.9989410166240409\n",
      "Iter 756 / 2000, Loss: 126727102.65625, CrossEntropy: 0.003927568439394236, Accuracy: 0.9987811700767263\n",
      "Iter 757 / 2000, Loss: 128219911.09375, CrossEntropy: 0.004519154783338308, Accuracy: 0.9985414002557544\n",
      "Iter 758 / 2000, Loss: 126053614.125, CrossEntropy: 0.003675434971228242, Accuracy: 0.9988890664961637\n",
      "Iter 759 / 2000, Loss: 125872430.40625, CrossEntropy: 0.003572432789951563, Accuracy: 0.998761189258312\n",
      "Iter 760 / 2000, Loss: 125024986.65625, CrossEntropy: 0.0032423678785562515, Accuracy: 0.9989490089514067\n",
      "Iter 761 / 2000, Loss: 127438650.96875, CrossEntropy: 0.004191477783024311, Accuracy: 0.998641304347826\n",
      "Iter 762 / 2000, Loss: 126601682.3125, CrossEntropy: 0.0038526987191289663, Accuracy: 0.9986612851662404\n",
      "Iter 763 / 2000, Loss: 129127803.90625, CrossEntropy: 0.004858051892369986, Accuracy: 0.9984015345268542\n",
      "Iter 764 / 2000, Loss: 126170909.0, CrossEntropy: 0.0036752570886164904, Accuracy: 0.9987412084398977\n",
      "Iter 765 / 2000, Loss: 125609244.71875, CrossEntropy: 0.003451381577178836, Accuracy: 0.9988610933503836\n",
      "Iter 766 / 2000, Loss: 124777706.28125, CrossEntropy: 0.0031470907852053642, Accuracy: 0.9990688938618926\n",
      "Iter 767 / 2000, Loss: 129040230.53125, CrossEntropy: 0.0048066009767353535, Accuracy: 0.9986612851662404\n",
      "Iter 768 / 2000, Loss: 127620068.1875, CrossEntropy: 0.0042350031435489655, Accuracy: 0.998761189258312\n",
      "Iter 769 / 2000, Loss: 125348867.59375, CrossEntropy: 0.003323544980958104, Accuracy: 0.9989010549872123\n",
      "Iter 770 / 2000, Loss: 125658503.1875, CrossEntropy: 0.0034468716476112604, Accuracy: 0.9988610933503836\n",
      "Iter 771 / 2000, Loss: 129640722.03125, CrossEntropy: 0.005029517691582441, Accuracy: 0.9983415920716112\n",
      "Iter 772 / 2000, Loss: 127974097.78125, CrossEntropy: 0.0043608457781374454, Accuracy: 0.9987412084398977\n",
      "Iter 773 / 2000, Loss: 125602789.78125, CrossEntropy: 0.003407224314287305, Accuracy: 0.9989410166240409\n",
      "Iter 774 / 2000, Loss: 127264406.25, CrossEntropy: 0.004067215137183666, Accuracy: 0.998641304347826\n",
      "Iter 775 / 2000, Loss: 127890822.375, CrossEntropy: 0.0043200356885790825, Accuracy: 0.9985613810741688\n",
      "Iter 776 / 2000, Loss: 124460582.09375, CrossEntropy: 0.002938805613666773, Accuracy: 0.9990609015345269\n",
      "Iter 777 / 2000, Loss: 127131059.28125, CrossEntropy: 0.004005757160484791, Accuracy: 0.9987212276214834\n",
      "Iter 778 / 2000, Loss: 126513404.46875, CrossEntropy: 0.0037535587325692177, Accuracy: 0.9987412084398977\n",
      "Iter 779 / 2000, Loss: 127147966.21875, CrossEntropy: 0.003998998552560806, Accuracy: 0.9985813618925832\n",
      "Iter 780 / 2000, Loss: 126674755.375, CrossEntropy: 0.0038257131818681955, Accuracy: 0.9987891624040921\n",
      "Iter 781 / 2000, Loss: 127192927.375, CrossEntropy: 0.004008932039141655, Accuracy: 0.998761189258312\n",
      "Iter 782 / 2000, Loss: 127025376.625, CrossEntropy: 0.003955424763262272, Accuracy: 0.9988291240409207\n",
      "Iter 783 / 2000, Loss: 128065295.75, CrossEntropy: 0.004348620306700468, Accuracy: 0.9986812659846548\n",
      "Iter 784 / 2000, Loss: 129885078.625, CrossEntropy: 0.005078965798020363, Accuracy: 0.9985014386189258\n",
      "Iter 785 / 2000, Loss: 125273982.21875, CrossEntropy: 0.0032246734481304884, Accuracy: 0.9989010549872123\n",
      "Iter 786 / 2000, Loss: 126524099.4375, CrossEntropy: 0.003719643922522664, Accuracy: 0.9987811700767263\n",
      "Iter 787 / 2000, Loss: 126973616.96875, CrossEntropy: 0.003930980339646339, Accuracy: 0.9986692774936061\n",
      "Iter 788 / 2000, Loss: 128439052.09375, CrossEntropy: 0.0045282780192792416, Accuracy: 0.9984574808184143\n",
      "Iter 789 / 2000, Loss: 124497484.9375, CrossEntropy: 0.002898377599194646, Accuracy: 0.9991008631713555\n",
      "Iter 790 / 2000, Loss: 128881237.3125, CrossEntropy: 0.004678049124777317, Accuracy: 0.9984494884910486\n",
      "Iter 791 / 2000, Loss: 126044961.9375, CrossEntropy: 0.0035083198454231024, Accuracy: 0.9989809782608695\n",
      "Iter 792 / 2000, Loss: 127806567.1875, CrossEntropy: 0.004210357088595629, Accuracy: 0.9986612851662404\n",
      "Iter 793 / 2000, Loss: 128929212.03125, CrossEntropy: 0.004659191705286503, Accuracy: 0.9986013427109974\n",
      "Iter 794 / 2000, Loss: 125196949.15625, CrossEntropy: 0.003165885806083679, Accuracy: 0.9989809782608695\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 795 / 2000, Loss: 128382846.5625, CrossEntropy: 0.004429686348885298, Accuracy: 0.9987212276214834\n",
      "Iter 796 / 2000, Loss: 128664494.78125, CrossEntropy: 0.004534875508397818, Accuracy: 0.9986612851662404\n",
      "Iter 797 / 2000, Loss: 126465434.96875, CrossEntropy: 0.0036519698332995176, Accuracy: 0.9987412084398977\n",
      "Iter 798 / 2000, Loss: 124035277.125, CrossEntropy: 0.0026772315613925457, Accuracy: 0.9990209398976982\n",
      "Iter 799 / 2000, Loss: 127284900.375, CrossEntropy: 0.003971862141042948, Accuracy: 0.9986013427109974\n",
      "Iter 800 / 2000, Loss: 125546098.46875, CrossEntropy: 0.0032731485553085804, Accuracy: 0.9990009590792839\n",
      "Iter 801 / 2000, Loss: 126280906.125, CrossEntropy: 0.003563451115041971, Accuracy: 0.9989210358056266\n",
      "Iter 802 / 2000, Loss: 126217172.15625, CrossEntropy: 0.003533458337187767, Accuracy: 0.9987811700767263\n",
      "Iter 803 / 2000, Loss: 125861540.1875, CrossEntropy: 0.0033875396475195885, Accuracy: 0.9989410166240409\n",
      "Iter 804 / 2000, Loss: 126956757.84375, CrossEntropy: 0.003821279387921095, Accuracy: 0.998761189258312\n",
      "Iter 805 / 2000, Loss: 126351343.25, CrossEntropy: 0.003576570190489292, Accuracy: 0.9987412084398977\n",
      "Iter 806 / 2000, Loss: 128714986.0, CrossEntropy: 0.0045166355557739735, Accuracy: 0.9986013427109974\n",
      "Iter 807 / 2000, Loss: 127489750.25, CrossEntropy: 0.004022585693746805, Accuracy: 0.998761189258312\n",
      "Iter 808 / 2000, Loss: 127047705.78125, CrossEntropy: 0.003841932164505124, Accuracy: 0.998821131713555\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-4d021ee3df5a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;31m#         images = images.view(images.shape[0], -1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         loss, y_pred,_ = sgd_model.training_step(\n\u001b[0m\u001b[1;32m     15\u001b[0m             \u001b[0mbatch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m             \u001b[0mN\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mN\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Langevin_Variational_Inference/Deep_Nets/CNNs/ResNet/src/components.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, N, vi_batch_size, deterministic_weights)\u001b[0m\n\u001b[1;32m    380\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    381\u001b[0m         \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m         Y_hat, sample_dict = self.sample_pred(\n\u001b[0m\u001b[1;32m    383\u001b[0m             \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    384\u001b[0m             \u001b[0mdeterministic\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdeterministic_weights\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Langevin_Variational_Inference/Deep_Nets/CNNs/ResNet/src/components.py\u001b[0m in \u001b[0;36msample_pred\u001b[0;34m(self, X, deterministic, vi_batch_size, for_training)\u001b[0m\n\u001b[1;32m    359\u001b[0m         )\n\u001b[1;32m    360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 361\u001b[0;31m         Y_hat = self.forward(\n\u001b[0m\u001b[1;32m    362\u001b[0m             \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    363\u001b[0m             \u001b[0msampled_weights\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Langevin_Variational_Inference/Deep_Nets/CNNs/ResNet/src/components.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, X, sampled_weights)\u001b[0m\n\u001b[1;32m    818\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    819\u001b[0m                 \u001b[0mtemp_h\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconv_b\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp_h\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconv_b_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 820\u001b[0;31m                 \u001b[0mtemp_h\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_norms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mconv_b_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp_h\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtemp_h\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mtemp_h\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    821\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    822\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0ms\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0min_planes\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mplanes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "num_epochs = 2000\n",
    "criterion = torch.nn.CrossEntropyLoss()  # loss function\n",
    "\n",
    "# MAP pre-training of the ResNet with deterministic weights.\n",
    "# Each epoch reports the mean per-batch training-step loss, cross-entropy and accuracy\n",
    "# (the same per-batch averaging used by the LVI training loop below).\n",
    "for i in range(num_epochs):\n",
    "    losses = []\n",
    "    cross_losses = []\n",
    "    accuracy = []\n",
    "\n",
    "    for images, labels in trainloader:\n",
    "        # Convolutional model: images are passed unflattened.\n",
    "        loss, y_pred, _ = sgd_model.training_step(\n",
    "            batch=(images, labels),\n",
    "            N=N,\n",
    "            deterministic_weights=True,\n",
    "            vi_batch_size=None,\n",
    "        )\n",
    "        losses.append(loss)\n",
    "\n",
    "        # .item() stores a plain float, so no autograd graph attached to y_pred (if any) stays alive for the whole epoch.\n",
    "        cross_losses.append(criterion(y_pred.squeeze(0), labels).item())\n",
    "        accuracy.append((torch.max(y_pred.squeeze(0), -1).indices == labels).sum().item() / labels.size(0))\n",
    "\n",
    "    print(\"Iter {} / {}, Loss: {}, CrossEntropy: {}, Accuracy: {}\".format(\n",
    "        i+1, num_epochs, sum(losses)/len(losses), sum(cross_losses)/len(cross_losses), sum(accuracy)/len(accuracy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EVALUATION with last weights -> Loss: 12420517.0, CrossEntropy: 1.9860225915908813, Accuracy: 0.8195213607594937\n"
     ]
    }
   ],
   "source": [
    "# Test-set evaluation of the MAP model with its final (deterministic) weights.\n",
    "# Each metric is averaged over test batches.\n",
    "losses = []\n",
    "cross_losses = []\n",
    "accuracy = []\n",
    "\n",
    "for images, labels in testloader:\n",
    "    loss, y_pred = sgd_model.evaluate(batch=(images, labels),\n",
    "                N=N,\n",
    "                num_samples=None,\n",
    "                deterministic_weights=True)\n",
    "    losses.append(loss)\n",
    "\n",
    "    # One deterministic prediction per batch, so its metrics are recorded directly.\n",
    "    cross_losses.append(criterion(y_pred.squeeze(0), labels).item())\n",
    "    accuracy.append((torch.max(y_pred.squeeze(0), -1).indices == labels).sum().item() / labels.size(0))\n",
    "\n",
    "print(\"EVALUATION with last weights -> Loss: {}, CrossEntropy: {}, Accuracy: {}\".format(sum(losses)/len(losses), sum(cross_losses)/len(cross_losses), sum(accuracy)/len(accuracy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional checkpoint: uncomment to save the MAP weights to the path the LVI section below loads them from.\n",
    "# NOTE(review): the file name says \"svhn\" although the datasets loaded at the top are CIFAR-10 - confirm.\n",
    "# torch.save(sgd_model.state_dict(), \"./sgd_resnet20_svhn_map.pt\")"
   ]
  },
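  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LVI or non-LVI models\n",
    "\n",
    "The stochastic ResNet-20 below is initialised from the MAP (SGD-trained) weights and trained with pSGLD."
   ]
  },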
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "291376\n"
     ]
    }
   ],
   "source": [
    "def load_model_params(path=\"./resnet20_sgd_model_params.pickle\"):\n",
    "    \"\"\"Return a fresh dict of BayesianResNet20 constructor kwargs read from a (trusted, local) pickle file.\"\"\"\n",
    "    # pickle.load can execute arbitrary code - only use it on files you created yourself.\n",
    "    with open(path, \"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "# Rebuild the MAP (SGD-trained) model and restore its weights.\n",
    "sgd_model = BayesianResNet20(**load_model_params())\n",
    "sgd_model.load_state_dict(torch.load(\"./sgd_resnet20_svhn_map.pt\", map_location=dev))\n",
    "# sgd_model.load_state_dict(torch.load(\"./cnn_svhn_non_lvi.pt\", map_location=dev))\n",
    "\n",
    "# Total number of scalar stochastic parameters (product of each parameter's shape).\n",
    "num_stoch_params = sum(math.prod(param.shape) for param in sgd_model.get_stochastic_params())\n",
    "print(num_stoch_params)\n",
    "\n",
    "# Same architecture for the LVI model, with its own grouping / chain / prior settings\n",
    "# and weights initialised from the MAP solution.\n",
    "lvi_model_params = load_model_params()\n",
    "lvi_model_params[\"group_by_layers\"] = False\n",
    "lvi_model_params[\"use_random_groups\"] = False\n",
    "lvi_model_params[\"use_permuted_groups\"] = True\n",
    "lvi_model_params[\"max_groups\"] = 2\n",
    "lvi_model_params[\"dropout_prob\"] = None\n",
    "lvi_model_params[\"chain_length\"] = 5000\n",
    "lvi_model_params[\"prior_std\"] = 0.3\n",
    "# lvi_model_params[\"output_distribution\"] = \"categorical\"\n",
    "# lvi_model_params[\"output_dist_const_params\"] = dict(scale=1.0)\n",
    "\n",
    "lvi_model_params[\"init_values\"] = {k: v.theta_actual.data for k, v in sgd_model.tensor_dict.items()}\n",
    "del sgd_model  # only its weights (kept in init_values) are needed from here on\n",
    "\n",
    "lvi_model = BayesianResNet20(**lvi_model_params)\n",
    "\n",
    "# Update only the stochastic tensors, using the pSGLD optimizer.\n",
    "lvi_model.initialize_optimizer(\n",
    "    update_determ=False,\n",
    "    update_stoch=True,\n",
    "    lr=1e-3,\n",
    "    rmsprop=False,\n",
    "    sgd=False,\n",
    "    sgld=False,\n",
    "    psgld=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the LVI model on `dev` (chosen at the top of the notebook); the loc/scale tensors of each\n",
    "# stochastic tensor's prior distribution are moved explicitly as well.\n",
    "lvi_model = lvi_model.to(dev)\n",
    "stochastic_tensors = [tensor for tensor in lvi_model.tensor_dict.values() if isinstance(tensor, StochasticTensor)]\n",
    "for stoch_tensor in stochastic_tensors:\n",
    "    prior = stoch_tensor.prior_dist\n",
    "    prior.loc = prior.loc.to(dev)\n",
    "    prior.scale = prior.scale.to(dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before initialization: tensor([0, 0], device='cuda:0')\n",
      "After initialization: tensor([1, 1], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "def print_samples_per_group(template):\n",
    "    \"\"\"Print lvi_model.num_samples_per_group using a one-placeholder format template.\"\"\"\n",
    "    print(template.format(lvi_model.num_samples_per_group))\n",
    "\n",
    "\n",
    "# Initialise the LVI model's chains, showing the per-group sample counts before and after.\n",
    "print_samples_per_group(\"Before initialization: {}\")\n",
    "lvi_model.init_chains()\n",
    "print_samples_per_group(\"After initialization: {}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(lvi_model, testloader, num_samples=100):\n",
    "    \"\"\"Monte-Carlo evaluation of `lvi_model` on `testloader` with stochastic weights.\n",
    "\n",
    "    For each test batch, `lvi_model.evaluate` is called with `num_samples` weight samples. The\n",
    "    cross-entropy and accuracy of every sample are averaged into one value per batch, and the\n",
    "    per-batch values are then averaged over the loader. Uses the notebook globals `N` and `criterion`.\n",
    "\n",
    "    Returns the mean per-batch accuracy (also printed together with the loss and cross-entropy).\n",
    "    \"\"\"\n",
    "    losses = []\n",
    "    cross_losses = []\n",
    "    accuracy = []\n",
    "\n",
    "    for images, labels in testloader:\n",
    "        loss, y_pred = lvi_model.evaluate(batch=(images, labels),\n",
    "                    N=N,\n",
    "                    num_samples=num_samples,\n",
    "                    deterministic_weights=False)\n",
    "        losses.append(loss)\n",
    "\n",
    "        # Score each weight sample; samples are indexed along y_pred's first axis\n",
    "        # (assumed (num_samples, batch, classes) - TODO confirm against evaluate()).\n",
    "        inner_cross_losses = []\n",
    "        inner_accuracy = []\n",
    "        for sample_logits in y_pred:\n",
    "            inner_cross_losses.append(criterion(sample_logits, labels).item())\n",
    "            inner_accuracy.append((torch.max(sample_logits, -1).indices == labels).sum().item() / labels.size(0))\n",
    "\n",
    "        # Record exactly one averaged value per batch, after all samples have been scored.\n",
    "        accuracy.append(sum(inner_accuracy)/len(inner_accuracy))\n",
    "        cross_losses.append(sum(inner_cross_losses)/len(inner_cross_losses))\n",
    "\n",
    "    mean_accuracy = sum(accuracy)/len(accuracy)\n",
    "    print(\"EVALUATION with {} samples -> Loss: {}, CrossEntropy: {}, Accuracy: {}\".format(num_samples, sum(losses)/len(losses), sum(cross_losses)/len(cross_losses), mean_accuracy))\n",
    "    return mean_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aalexos/Langevin_Variational_Inference/Deep_Nets/CNNs/CIFAR10/src/optimizers.py:90: UserWarning: This overload of addcmul_ is deprecated:\n",
      "\taddcmul_(Number value, Tensor tensor1, Tensor tensor2)\n",
      "Consider using one of the following signatures instead:\n",
      "\taddcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)\n",
      "  square_avg.mul_(alpha).addcmul_(1-alpha, d_p, d_p)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 1 / 2, Loss: 33090312.598465472, CrossEntropy: 5.128942966461182, Accuracy: 0.5304747442455243\n",
      "Elapsed time for the training: 15.22326374053955\n",
      "EVALUATION with 100 samples -> Loss: 35828936.0, CrossEntropy: 5.644075870513916, Accuracy: 0.48671953702526055\n",
      "Iter 2 / 2, Loss: 17418829.18286445, CrossEntropy: 2.6807048320770264, Accuracy: 0.5821451406649617\n",
      "Elapsed time for the training: 20.991058111190796\n",
      "EVALUATION with 100 samples -> Loss: 26794088.0, CrossEntropy: 4.279408931732178, Accuracy: 0.5027511831003681\n"
     ]
    }
   ],
   "source": [
    "# LVI training. The first `num_warmup_epochs` epochs use deterministic weights (vi_batch_size=1);\n",
    "# afterwards stochastic weights are sampled and the metrics are averaged over the samples in y_pred.\n",
    "# The test set is evaluated after every epoch and the returned accuracy is collected in total_acc.\n",
    "import time\n",
    "\n",
    "num_epochs = 2\n",
    "num_warmup_epochs = 1\n",
    "criterion = torch.nn.CrossEntropyLoss()  # loss function\n",
    "total_acc = []\n",
    "\n",
    "for i in range(num_epochs):\n",
    "    losses = []\n",
    "    cross_losses = []\n",
    "    accuracy = []\n",
    "\n",
    "    start = time.time()\n",
    "\n",
    "    for images, labels in trainloader:\n",
    "        if i < num_warmup_epochs:\n",
    "            loss, y_pred, _ = lvi_model.training_step(\n",
    "                batch=(images, labels),\n",
    "                N=N,\n",
    "                deterministic_weights=True,\n",
    "                vi_batch_size=1,\n",
    "            )\n",
    "            losses.append(loss)\n",
    "\n",
    "            # .item() keeps only a float, so no autograd graph (if y_pred carries one) is held for the epoch.\n",
    "            cross_losses.append(criterion(y_pred.squeeze(0), labels).item())\n",
    "            accuracy.append((torch.max(y_pred.squeeze(0), -1).indices == labels).sum().item() / labels.size(0))\n",
    "        else:\n",
    "            loss, y_pred, _ = lvi_model.training_step(\n",
    "                batch=(images, labels),\n",
    "                N=N,\n",
    "                deterministic_weights=False,\n",
    "                vi_batch_size=None,\n",
    "            )\n",
    "            losses.append(loss)\n",
    "\n",
    "            # One entry per weight sample along y_pred's first axis; average them into one value per batch.\n",
    "            inner_cross_losses = []\n",
    "            inner_accuracy = []\n",
    "            for sample_logits in y_pred:\n",
    "                inner_cross_losses.append(criterion(sample_logits, labels).item())\n",
    "                inner_accuracy.append((torch.max(sample_logits, -1).indices == labels).sum().item() / labels.size(0))\n",
    "\n",
    "            accuracy.append(sum(inner_accuracy)/len(inner_accuracy))\n",
    "            cross_losses.append(sum(inner_cross_losses)/len(inner_cross_losses))\n",
    "\n",
    "    print(\"Iter {} / {}, Loss: {}, CrossEntropy: {}, Accuracy: {}\".format(i+1, num_epochs, sum(losses)/len(losses), sum(cross_losses)/len(cross_losses), sum(accuracy)/len(accuracy)))\n",
    "\n",
    "    end = time.time()\n",
    "    print('Elapsed time for the training:', end - start)\n",
    "\n",
    "    total_acc.append(evaluation(lvi_model, testloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional checkpoint: uncomment to save the trained model's weights.\n",
    "# (The commented-out alternative load in the MAP-loading cell above reads this same file.)\n",
    "# torch.save(lvi_model.state_dict(), \"./cnn_svhn_non_lvi.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EVALUATION with 100 samples -> Loss: 11035387.0, CrossEntropy: 3.316812038421631, Accuracy: 0.8987949346405288\n"
     ]
    }
   ],
   "source": [
    "# Final Monte-Carlo evaluation of the LVI model on the test set (100 weight samples per batch).\n",
    "# NOTE(review): this rebinds the global N (passed as N= to evaluate / training_step) to the test-set\n",
    "# size, so any cell re-run afterwards also sees it - confirm this is intended.\n",
    "N = len(valset)\n",
    "losses = []\n",
    "cross_losses = []\n",
    "accuracy = []\n",
    "\n",
    "for images, labels in testloader:\n",
    "    loss, y_pred = lvi_model.evaluate(batch=(images, labels),\n",
    "                N=N,\n",
    "                num_samples=100,\n",
    "                deterministic_weights=False)\n",
    "    losses.append(loss)\n",
    "\n",
    "    # Score every weight sample along y_pred's first axis, then record one averaged value per batch.\n",
    "    inner_cross_losses = []\n",
    "    inner_accuracy = []\n",
    "    for sample_logits in y_pred:\n",
    "        inner_cross_losses.append(criterion(sample_logits, labels).item())\n",
    "        inner_accuracy.append((torch.max(sample_logits, -1).indices == labels).sum().item() / labels.size(0))\n",
    "\n",
    "    accuracy.append(sum(inner_accuracy)/len(inner_accuracy))\n",
    "    cross_losses.append(sum(inner_cross_losses)/len(inner_cross_losses))\n",
    "\n",
    "print(\"EVALUATION with 100 samples -> Loss: {}, CrossEntropy: {}, Accuracy: {}\".format(sum(losses)/len(losses), sum(cross_losses)/len(cross_losses), sum(accuracy)/len(accuracy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy after each LVI training epoch (collected in total_acc by the training loop).\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.plot(range(1, len(total_acc) + 1), total_acc)\n",
    "ax.set(title=\"LVI test accuracy per epoch\", xlabel=\"Epoch\", ylabel=\"Test accuracy\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the LVI model's dropout flag and its per-group sample counts.\n",
    "(lvi_model.use_dropout, lvi_model.num_samples_per_group)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "sns.set_style(\"white\")\n",
    "\n",
    "# Pair plot of the sampled W_0 chain values (first axis assumed to index chain samples - TODO confirm).\n",
    "# pd.DataFrame needs 2-D input, so the chains are flattened to (chain entries, coordinates);\n",
    "# a PairGrid draws one panel per coordinate pair, so only the first `num_plot_coords`\n",
    "# coordinates are plotted (set it to None to plot all of them).\n",
    "num_plot_coords = 5\n",
    "w0_chains = lvi_model.get_chains()[\"W_0\"].squeeze().cpu().detach()\n",
    "to_plot = w0_chains.reshape(w0_chains.shape[0], -1)[:, :num_plot_coords].numpy()\n",
    "\n",
    "g = sns.PairGrid(pd.DataFrame(to_plot))\n",
    "g.map_diag(plt.hist, bins=100)\n",
    "\n",
    "def pairgrid_heatmap(x, y, **kws):\n",
    "    \"\"\"Off-diagonal PairGrid panel: 2-D histogram shaded with the panel's assigned colour.\"\"\"\n",
    "    cmap = sns.light_palette(kws.pop(\"color\"), as_cmap=True)\n",
    "    plt.hist2d(x, y, cmap=cmap, cmin=1, **kws)\n",
    "\n",
    "g.map_offdiag(pairgrid_heatmap, bins=100);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the W_0 chains into a 2-D table: one row per entry along the first axis, one column per coordinate.\n",
    "w0_chains = lvi_model.get_chains()[\"W_0\"].squeeze().cpu().detach()\n",
    "shaping = w0_chains.shape[0]\n",
    "w = pd.DataFrame(w0_chains.reshape(shaping, -1).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the flattened W_0 chain table (no header row; pandas writes the row index as the first column).\n",
    "# NOTE(review): the name mentions \"drop_10\" although dropout_prob is set to None for this run - confirm.\n",
    "w.to_csv('non_lvi_drop_10_batch_500.csv', header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
