{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook reproduces the experiment in Appendix A.2 of the paper, which shows that smaller batch sizes do not improve the performance of deep batch-normalized (BN) convolutional networks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preliminaries and Function Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import math\n",
    "import time\n",
    "import pickle\n",
    "from datetime import datetime\n",
    "import re\n",
    "import gc\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint as cp\n",
    "from collections import OrderedDict\n",
    "from torch import Tensor\n",
    "from torch.jit.annotations import List\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class PrepareData(torch.utils.data.Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X)\n",
    "        else:\n",
    "            self.X = X\n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y)\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# functions for solving the nonconvex batchnorm problem, with a deep convolutional network\n",
    "class BatchNormNetDeep(torch.nn.Module):\n",
    "    def __init__(self, c, num_neurons, d_out, kernel_size=3, padding=1,  momentum=1.0):\n",
    "        super(BatchNormNetDeep, self).__init__()\n",
    "        self.layer1 = torch.nn.Sequential(\n",
    "                    torch.nn.Conv2d(c, num_neurons, kernel_size=kernel_size, padding=padding, bias=False),\n",
    "                    torch.nn.BatchNorm2d(num_neurons, momentum=momentum),\n",
    "                    torch.nn.ReLU(),\n",
    "                    torch.nn.AvgPool2d(2))\n",
    "        \n",
    "        self.layer2 = torch.nn.Sequential(\n",
    "                    torch.nn.Conv2d(num_neurons, num_neurons, kernel_size=kernel_size, padding=padding, bias=False),\n",
    "                    torch.nn.BatchNorm2d(num_neurons, momentum=momentum),\n",
    "                    torch.nn.ReLU(),\n",
    "                    torch.nn.AvgPool2d(2))\n",
    "        self.layer3 = torch.nn.Sequential(\n",
    "                    torch.nn.Conv2d(num_neurons, num_neurons, kernel_size=kernel_size, padding=padding, bias=False),\n",
    "                    torch.nn.BatchNorm2d(num_neurons, momentum=momentum),\n",
    "                    torch.nn.ReLU(),\n",
    "                    torch.nn.AvgPool2d(8))\n",
    "        \n",
    "        self.layer4 = torch.nn.Linear(num_neurons, d_out, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.layer3(self.layer2(self.layer1(x)))\n",
    "        out = out.view(out.shape[0], -1)\n",
    "        out = self.layer4(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "def loss_func_deep(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    # standard weight decay regularization \n",
    "    for p in model.named_parameters():\n",
    "        loss = loss + beta/2 * torch.norm(p[1])**2\n",
    "    return loss\n",
    "\n",
    "def validation_deep(model, testloader, beta, device):\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "    \n",
    "    model.eval()\n",
    "\n",
    "    for ix, (_x, _y) in enumerate(testloader):\n",
    "        _x = Variable(_x).float().to(device)\n",
    "        _y = Variable(_y).float().to(device)\n",
    "\n",
    "        output = model.forward(_x)\n",
    "        yhat = model(_x).float()\n",
    "\n",
    "        loss = loss_func_deep(yhat, _y, model, beta)\n",
    "\n",
    "        test_loss += loss.item()\n",
    "        if _y.shape[1] == 1:\n",
    "            test_correct += torch.eq(yhat>0.5, _y).float().sum()\n",
    "        else:\n",
    "            test_correct += torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "def batchnorm_full_gd_deep(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                       learning_rate, verbose=False, batch_size=None, bn_momentum=1.0, \n",
    "                      print_freq=5, device='cpu'):\n",
    "    device = torch.device(device)\n",
    "    # D_in is input dimension, H is hidden dimension, D_out is output dimension.\n",
    "    c, H, D_out = A_train.shape[1], num_neurons, y_train.shape[1]\n",
    "    \n",
    "    # use gd if no batch size is passed in \n",
    "    if batch_size is None:\n",
    "        batch_size = A_train.shape[0]\n",
    "        \n",
    "    # create the model\n",
    "    model = BatchNormNetDeep(c, H, D_out, momentum=bn_momentum).to(device)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \n",
    "                                                           mode='min', \n",
    "                                                           factor=0.5, \n",
    "                                                           patience=5, \n",
    "                                                           verbose=verbose)\n",
    "    \n",
    "    # arrays for saving the loss and accuracy    \n",
    "    losses = np.zeros((int(num_epochs*np.ceil(A_train.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "\n",
    "    \n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData(X=A_train, y=y_train)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=False)\n",
    "    ds_test = PrepareData(X=A_test, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=A_test.shape[0], shuffle=False)\n",
    "    \n",
    "    losses_test[0], accs_test[0] = validation_deep(model, ds_test, beta, device) # loss on the entire test set\n",
    "    \n",
    "    print('starting training')\n",
    "    iter_no = 0\n",
    "    for i in range(num_epochs):\n",
    "        model.train()\n",
    "        for ix, (_x, _y) in enumerate(ds):\n",
    "            _x = Variable(_x).float().to(device)\n",
    "            _y = Variable(_y).float().to(device)\n",
    "            \n",
    "            \n",
    "            yhat = model(_x).float()\n",
    "            \n",
    "            loss = loss_func_deep(yhat, _y, model, beta)\n",
    "            if D_out == 1:\n",
    "                correct = torch.eq(yhat>0.5, _y).float().sum()\n",
    "            else:\n",
    "                correct = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the gradients w.r.t the loss\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            accs[iter_no] = correct\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        \n",
    "        scheduler.step(loss.item())\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation_deep(model, ds_test, beta, device) # loss on the entire test set\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1]/batch_size, 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "            \n",
    "    return losses, accs/batch_size, losses_test, accs_test/A_test.shape[0], times, model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Binary Classification Data (CIFAR-100, Classes 0 and 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cifar-10 -- using the version downloaded from \"http://www.cs.toronto.edu/~kriz/cifar.html\"\n",
    "def unpickle(file): \n",
    "    import pickle\n",
    "    with open(file, 'rb') as fo:\n",
    "        dict = pickle.load(fo, encoding='bytes')\n",
    "    return dict\n",
    "\n",
    "directory = '' # CHANGE DIRECTORY HERE\n",
    "\n",
    "train = unpickle(directory + \"train\")\n",
    "A = train[b'data'].astype(np.float64)\n",
    "train_labels = np.array(train[b'fine_labels'])\n",
    "y = train_labels.copy().reshape((train_labels.shape[0], 1))\n",
    "\n",
    "# test set\n",
    "batch = unpickle(directory + \"test\")\n",
    "A_test = batch[b'data'].astype(np.float64)\n",
    "test_labels = np.array(batch[b'fine_labels'])\n",
    "y_test = test_labels.copy().reshape((test_labels.shape[0], 1))\n",
    "\n",
    "# to get the first two classes only (for binary classification)\n",
    "inds = np.argwhere(y <= 1)[:,0] # get the classes 0 and 1\n",
    "A = A[inds, :]\n",
    "y = y[inds].reshape((inds.shape[0], 1))\n",
    "\n",
    "inds_test = np.argwhere(y_test <= 1)[:,0] # get the classes 0 and 1\n",
    "A_test = A_test[inds_test, :]\n",
    "y_test = y_test[inds_test].reshape((inds_test.shape[0], 1))\n",
    "\n",
    "A_test /= 255\n",
    "A /= 255\n",
    "\n",
    "# demean the training data and also remove the training mean from the test data\n",
    "A_test -= np.mean(A, 0)\n",
    "A -= np.mean(A, 0)\n",
    "\n",
    "n, d = A.shape\n",
    "print(A.shape, y.shape, A_test.shape, y_test.shape)\n",
    "A_conv = A.reshape(A.shape[0], 3, 32, 32)\n",
    "A_conv_test = A_test.reshape(A_test.shape[0], 3, 32, 32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Network and Training Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = 1e-4\n",
    "num_neurons = 1000\n",
    "num_epochs = 501\n",
    "learning_rate = 1e-6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train Full-Batch GD Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_gd_deep = batchnorm_full_gd_deep(A_conv, y, A_conv_test, \n",
    "                             y_test, num_epochs, num_neurons, beta, \n",
    "                             learning_rate, verbose=True, bn_momentum=1.0,\n",
    "                              print_freq=20, device='cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train SGD Network with Various Batch Sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_sgd_bs50 = batchnorm_full_gd_deep(A_conv, y, A_conv_test, \n",
    "                             y_test, 461, num_neurons, beta, \n",
    "                             learning_rate, verbose=True, \n",
    "                            bn_momentum=0.1,batch_size=50,\n",
    "                                print_freq=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_sgd_bs100 = batchnorm_full_gd_deep(A_conv, y, A_conv_test, \n",
    "                             y_test, num_epochs, num_neurons, beta, \n",
    "                             learning_rate, verbose=True, \n",
    "                            bn_momentum=0.1,batch_size=100,\n",
    "                                print_freq=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_sgd_bs500 = batchnorm_full_gd_deep(A_conv, y, A_conv_test, \n",
    "                             y_test, num_epochs, num_neurons, beta, \n",
    "                             learning_rate, verbose=True, \n",
    "                            bn_momentum=0.1,batch_size=500,\n",
    "                                print_freq=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 6))\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "plt.xlabel('Time (sec)');  plt.grid()\n",
    "\n",
    "ylabel_list = [\"Cost\", \"Accuracy\"]\n",
    "plt.ylabel(ylabel_list[1])\n",
    "# plt.title('Test Loss: FC Networks with BN before ReLU, Conv Sign Patterns--CIFAR-10 Classification, n={}, d={}, P={}'.format(n, d, num_neurons))\n",
    "\n",
    "plt.plot(results_sgd_bs50[-2][::20]- results_sgd_bs50[-2][0],\n",
    "         results_sgd_bs50[3], label=\"Test--bs=50\",\n",
    "        linewidth=3,linestyle='dashed', color='b')\n",
    "plt.plot(results_sgd_bs100[-2][::10]- results_sgd_bs100[-2][0],\n",
    "         results_sgd_bs100[3], label=\"Test--bs=100\",\n",
    "        linewidth=3,linestyle='dashed', color='orange')\n",
    "plt.plot(results_sgd_bs500[-2][::2]- results_sgd_bs500[-2][0],\n",
    "         results_sgd_bs500[3], label=\"Test--bs=500\",\n",
    "        linewidth=3,linestyle='dashed', color='green')\n",
    "plt.plot(results_gd_deep[-2]- results_gd_deep[-2][0],\n",
    "         results_gd_deep[3], label=\"Test--bs=1000\",\n",
    "        linewidth=3, linestyle='dashed', color='red')\n",
    "\n",
    "ids = np.arange(len(results_sgd_bs50[1]))//20\n",
    "out = np.bincount(ids,results_sgd_bs50[1])/np.bincount(ids)\n",
    "\n",
    "time_out = np.bincount(ids,results_sgd_bs50[-2][1:])/np.bincount(ids)\n",
    "\n",
    "plt.plot(time_out- results_sgd_bs50[-2][0],\n",
    "         out, label=\"Train--bs=50\",\n",
    "        linewidth=3, color='b')\n",
    "\n",
    "ids = np.arange(len(results_sgd_bs100[1]))//10\n",
    "out = np.bincount(ids,results_sgd_bs100[1])/np.bincount(ids)\n",
    "time_out = np.bincount(ids,results_sgd_bs100[-2][1:])/np.bincount(ids)\n",
    "\n",
    "plt.plot(time_out- results_sgd_bs100[-2][0],\n",
    "        out, label=\"Train,bs=100\",\n",
    "        linewidth=3, color='orange')\n",
    "\n",
    "\n",
    "ids = np.arange(len(results_sgd_bs500[1]))//2\n",
    "out = np.bincount(ids,results_sgd_bs500[1])/np.bincount(ids)\n",
    "time_out = np.bincount(ids,results_sgd_bs500[-2][1:])/np.bincount(ids)\n",
    "\n",
    "plt.plot(time_out- results_sgd_bs500[-2][0],\n",
    "         out, label=\"Train--bs=500\",\n",
    "        linewidth=3, color='green')\n",
    "plt.plot(results_gd_deep[-2][1:]- results_gd_deep[-2][0],\n",
    "         results_gd_deep[1], label=\"Train--bs=1000\",\n",
    "        linewidth=3, color='red')\n",
    "plt.legend()\n",
    "\n",
    "plt.legend()\n",
    "plt.xticks([0, 10000, 20000])\n",
    "axes = plt.gca()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
