{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code here demonstrates the speed and effectiveness of the closed-form formula for vector-output BN networks when $n \\leq d$, as presented in Section 7 of the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preliminaries and Function Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import math\n",
    "import time\n",
    "import pickle\n",
    "from datetime import datetime\n",
    "import re\n",
    "import gc\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint as cp\n",
    "from collections import OrderedDict\n",
    "from torch import Tensor\n",
    "from torch.jit.annotations import List\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class PrepareData(torch.utils.data.Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X)\n",
    "        else:\n",
    "            self.X = X\n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y)\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "    \n",
    "def one_hot(labels, num_classes=10):\n",
    "    \"\"\"Embedding labels to one-hot form.\n",
    "\n",
    "    Args:\n",
    "      labels: (LongTensor) class labels, sized [N,].\n",
    "      num_classes: (int) number of classes.\n",
    "\n",
    "    Returns:\n",
    "      (tensor) encoded labels, sized [N, #classes].\n",
    "    \"\"\"\n",
    "    y = torch.eye(num_classes) \n",
    "    return y[labels] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# functions for solving the nonconvex batchnorm problem\n",
    "class BatchNormNet(torch.nn.Module):\n",
    "    def __init__(self, d, num_neurons, d_out, momentum=1.0, n=1500):\n",
    "        super(BatchNormNet, self).__init__()\n",
    "        self.layer1 = torch.nn.Linear(d, num_neurons, bias=False)\n",
    "        self.gamma = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=True)\n",
    "        torch.nn.init.ones_(self.gamma)\n",
    "        self.eta = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=True)\n",
    "        torch.nn.init.zeros_(self.eta)\n",
    "        \n",
    "        self.act = torch.nn.ReLU()\n",
    "        self.layer2 = torch.nn.Linear(num_neurons, d_out, bias=False)\n",
    "        \n",
    "        self.momentum = momentum\n",
    "        self.n = n\n",
    "        \n",
    "        self.mean = torch.nn.Parameter(data=torch.zeros(num_neurons), requires_grad=False)\n",
    "        self.std = torch.nn.Parameter(data=torch.ones(num_neurons), requires_grad=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.training:\n",
    "            out = self.layer1(x)\n",
    "            next_mean = torch.mean(out, 0)\n",
    "            self.mean = torch.nn.Parameter(data =self.momentum*next_mean + (1-self.momentum)*self.mean.data, requires_grad=False)\n",
    "            \n",
    "            out = out - next_mean\n",
    "            next_std = torch.norm(out, dim=0)\n",
    "            self.std = torch.nn.Parameter(data=self.momentum*next_std + (1-self.momentum)*self.std.data, requires_grad=False)\n",
    "            \n",
    "            out = out/next_std\n",
    "            out = self.layer2(self.act(out*self.gamma + self.eta/np.sqrt(n)))\n",
    "        else:\n",
    "            out = self.layer1(x)\n",
    "            out = out - self.mean\n",
    "            out = out/self.std\n",
    "            out = self.layer2(self.act(out*self.gamma + self.eta/np.sqrt(n)))\n",
    "\n",
    "        return out\n",
    "\n",
    "def loss_func(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    # standard weight decay regularization \n",
    "    for p in model.named_parameters():\n",
    "        if p[0] in ['gamma', 'eta', 'layer2.weight']:\n",
    "            loss = loss + beta/2 * torch.norm(p[1])**2\n",
    "    return loss\n",
    "\n",
    "def validation(model, testloader, beta):\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "    \n",
    "    model.eval()\n",
    "\n",
    "    for ix, (_x, _y) in enumerate(testloader):\n",
    "        _x = Variable(_x).float()\n",
    "        _y = Variable(_y).float()\n",
    "\n",
    "        output = model.forward(_x)\n",
    "        yhat = model(_x).float()\n",
    "\n",
    "        loss = loss_func(yhat, _y, model, beta)\n",
    "\n",
    "        test_loss += loss.item()\n",
    "        if _y.shape[1] == 1:\n",
    "            test_correct += torch.eq(yhat>0.5, _y).float().sum()\n",
    "        else:\n",
    "            test_correct += torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "def batchnorm_full_gd(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                       learning_rate, verbose=False, batch_size=None, bn_momentum=1.0, \n",
    "                      print_freq=5, device='cpu'):\n",
    "    device = torch.device('cpu')\n",
    "    # D_in is input dimension, H is hidden dimension, D_out is output dimension.\n",
    "    D_in, H, D_out = A_train.shape[1], num_neurons, y_train.shape[1]\n",
    "    \n",
    "    # use gd if no batch size is passed in \n",
    "    if batch_size is None:\n",
    "        batch_size = A_train.shape[0]\n",
    "        \n",
    "    # create the model\n",
    "    model = BatchNormNet(D_in, H, D_out, momentum=bn_momentum).to(device)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \n",
    "                                                           mode='min', \n",
    "                                                           factor=0.5, \n",
    "                                                           patience=2*(print_freq//5), \n",
    "                                                           verbose=verbose)\n",
    "    \n",
    "    # arrays for saving the loss and accuracy    \n",
    "    losses = np.zeros((int(num_epochs*np.ceil(A_train.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "\n",
    "    \n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData(X=A_train, y=y_train)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=False)\n",
    "    ds_test = PrepareData(X=A_test, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=A_test.shape[0], shuffle=False)\n",
    "    \n",
    "    losses_test[0], accs_test[0] = validation(model, ds_test, beta) # loss on the entire test set\n",
    "    \n",
    "    iter_no = 0\n",
    "    for i in range(num_epochs):\n",
    "        model.train()\n",
    "        for ix, (_x, _y) in enumerate(ds):\n",
    "            _x = Variable(_x).float()\n",
    "            _y = Variable(_y).float()\n",
    "            \n",
    "            \n",
    "            yhat = model(_x).float()\n",
    "            \n",
    "            loss = loss_func(yhat, _y, model, beta)\n",
    "            if D_out == 1:\n",
    "                correct = torch.eq(yhat>0.5, _y).float().sum()\n",
    "            else:\n",
    "                correct = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the gradients w.r.t the loss\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            accs[iter_no] = correct\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        \n",
    "        scheduler.step(loss.item())\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation(model, ds_test, beta) # loss on the entire test set\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1]/batch_size, 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "            \n",
    "    return losses, accs/batch_size, losses_test, accs_test/A_test.shape[0], times, model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def closed_form_solution(A_train, y_train, A_test, y_test, beta):\n",
    "    n = A_train.shape[0]\n",
    "    n_test = A_test.shape[0]\n",
    "    pinv = torch.pinverse(A_train)\n",
    "    \n",
    "    # w_j^1 = X^dagger, y_j\n",
    "    W_1 = pinv @ y_train\n",
    "    \n",
    "    # w_j^2 = (\\|y_j\\|_2 - \\beta)_+ e_j\n",
    "    W_2 = torch.diag(F.relu(torch.norm(y_train, dim=0) - beta))\n",
    "    \n",
    "    # gamma_j = 1/\\|y_j\\|_2 * \\|y_j - 1/n 11^T y_j\\|_2\n",
    "    gamma = 1/torch.norm(y_train, dim=0) * torch.norm(y_train-torch.mean(y_train, dim=0), dim=0)\n",
    "    \n",
    "    # alpha_j = 1/\\|y_j\\|_2 * 1/sqrt(n) 1^T y_j \n",
    "    alpha = 1/torch.norm(y_train, dim=0) * torch.mean(y_train, dim=0)*np.sqrt(n)\n",
    "    \n",
    "    # compute prediction. The training output of the first layer will always be the labels, so \n",
    "    # we plug this in directly rather than the pseudo-inverse formula due to numerical stability issues.\n",
    "    first =  y\n",
    "    yhat = F.relu((first-torch.mean(first, dim=0)) / torch.norm((first-torch.mean(first, dim=0)), dim=0) @ torch.diag(gamma) + alpha.unsqueeze(0)/np.sqrt(n)) @ W_2\n",
    "    \n",
    "    # compute loss\n",
    "    train_loss = 0.5 * torch.norm(yhat - y_train)**2\n",
    "    train_loss +=  beta * torch.sum(torch.norm(W_2, dim=0))\n",
    "    \n",
    "    #compute accuracy\n",
    "    train_acc = torch.eq(torch.argmax(yhat, 1), torch.argmax(y_train, 1)).float().sum()/n\n",
    "    \n",
    "    #compute test prediction\n",
    "    first_test = A_test @ W_1\n",
    "    yhat_test = F.relu((first_test-torch.mean(first_test, dim=0)) / torch.norm((first_test-torch.mean(first_test, dim=0)),dim=0) @torch.diag(gamma) + alpha.unsqueeze(0)/np.sqrt(n)) @ W_2\n",
    "    \n",
    "    # compute test loss\n",
    "    test_loss = 0.5 * torch.norm(yhat_test - y_test)**2\n",
    "    test_loss +=  beta * torch.sum(torch.norm(W_2, dim=0))\n",
    "    \n",
    "    # compute test accuracy\n",
    "    test_acc = torch.eq(torch.argmax(yhat_test, 1), torch.argmax(y_test, 1)).float().sum()/n_test\n",
    "    \n",
    "    return train_loss, train_acc, test_loss, test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data--Three-class classification, CIFAR-100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1500, 3072]) torch.Size([1500, 3]) torch.Size([300, 3072]) torch.Size([300, 3])\n"
     ]
    }
   ],
   "source": [
    "# cifar-100 -- using the version downloaded from \"http://www.cs.toronto.edu/~kriz/cifar.html\"\n",
    "def unpickle(file): \n",
    "    import pickle\n",
    "    with open(file, 'rb') as fo:\n",
    "        dict = pickle.load(fo, encoding='bytes')\n",
    "    return dict\n",
    "directory = '' # CHANGE DIRECTORY HERE\n",
    "\n",
    "train = unpickle(directory + \"train\")\n",
    "A = train[b'data'].astype(np.float64)\n",
    "train_labels = np.array(train[b'fine_labels'])\n",
    "y = train_labels.copy().reshape((train_labels.shape[0], 1))\n",
    "\n",
    "# test set\n",
    "batch = unpickle(directory + \"test\")\n",
    "A_test = batch[b'data'].astype(np.float64)\n",
    "test_labels = np.array(batch[b'fine_labels'])\n",
    "y_test = test_labels.copy().reshape((test_labels.shape[0], 1))\n",
    "\n",
    "# to get the first three classes only\n",
    "inds = np.argwhere(y <= 2)[:,0] # get the classes 0, 1 and 2\n",
    "A = A[inds, :]\n",
    "y = y[inds].reshape(inds.shape[0])\n",
    "\n",
    "inds_test = np.argwhere(y_test <= 2)[:,0] # get the classes 0, 1, and 2\n",
    "A_test = A_test[inds_test, :]\n",
    "y_test = y_test[inds_test].reshape(inds_test.shape[0])\n",
    "\n",
    "A_test = torch.Tensor(A_test/255)\n",
    "A = torch.Tensor(A/255)\n",
    "\n",
    "y = one_hot(y, 3)\n",
    "y_test = one_hot(y_test, 3)\n",
    "\n",
    "n, d = A.shape\n",
    "print(A.shape, y.shape, A_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Model Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = 1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Solve the problem using the closed-form expression and GD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "closed_form_start = time.time()\n",
    "result_closed_form = closed_form_solution(A, y, A_test, y_test, beta)\n",
    "closed_form_time = time.time()-closed_form_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "repeated_runs = []\n",
    "num_repeated_runs = 5\n",
    "num_epochs = 40001\n",
    "learning_rate = 1e-5\n",
    "bn_momentum = 1.0\n",
    "num_neurons = 1000\n",
    "\n",
    "for i in range(num_repeated_runs):\n",
    "    results_gd_current = batchnorm_full_gd(A, y, A_test, \n",
    "                                 y_test, num_epochs, num_neurons, beta, \n",
    "                                 learning_rate, verbose=True, bn_momentum=bn_momentum,\n",
    "                                  print_freq=2000, device='cuda')\n",
    "    repeated_runs.append(list(results_gd_current)[:-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 6))\n",
    "plt.grid()\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "plt.hlines(result_closed_form[0].item()/A.shape[0], xmin=0, xmax=repeated_runs_lower_lr[0][4][-1]-repeated_runs_lower_lr[0][4][0],\n",
    "           color='orange',  linewidth=3,linestyle='dashed', label='Theory')\n",
    "plt.scatter(closed_form_time, result_closed_form[0].item()//A.shape[0], marker='x', color='orange', s=150)\n",
    "\n",
    "for i in range(len(repeated_runs)):\n",
    "    plt.plot(repeated_runs[i][4][1:] -repeated_run[i][4][0], repeated_runs[i][0]//A.shape[0], linewidth=3, label='GD, Trial '+str(i+1))\n",
    "plt.ylabel('Objective Function')\n",
    "plt.xlabel('Time (s)')\n",
    "axes = plt.gca()\n",
    "axes.set_ylim([0, 0.5])\n",
    "plt.legend()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
