{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code here presents experiments on the exact convex formulation and implicit regularization of two-layer batch-normalization (BN) networks in which BN is applied before the ReLU, as presented in Section 7 of the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preliminaries and Function Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import math\n",
    "import time\n",
    "from datetime import datetime\n",
    "import pickle\n",
    "import re\n",
    "import gc\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint as cp\n",
    "from collections import OrderedDict\n",
    "from torch import Tensor\n",
    "from torch.jit.annotations import List\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class PrepareData(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset that pairs each feature row with its target.\n",
    "\n",
    "    X and y are stored unchanged when they are already torch tensors;\n",
    "    anything else is wrapped with torch.from_numpy.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, X, y):\n",
    "        self.X = X if torch.is_tensor(X) else torch.from_numpy(X)\n",
    "        self.y = y if torch.is_tensor(y) else torch.from_numpy(y)\n",
    "\n",
    "    def __len__(self):\n",
    "        # dataset size is the leading dimension of X\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample, target = self.X[idx], self.y[idx]\n",
    "        return sample, target\n",
    "    \n",
    "class PrepareData3D(Dataset):\n",
    "    \"\"\"Map-style dataset yielding (X[idx], y[idx], z[idx]) triples.\n",
    "\n",
    "    Each of X, y and z is kept as-is when it is already a torch tensor and\n",
    "    is otherwise wrapped with torch.from_numpy.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _as_tensor(data):\n",
    "        # leave tensors untouched; wrap NumPy arrays without copying\n",
    "        if torch.is_tensor(data):\n",
    "            return data\n",
    "        return torch.from_numpy(data)\n",
    "\n",
    "    def __init__(self, X, y, z):\n",
    "        self.X = self._as_tensor(X)\n",
    "        self.y = self._as_tensor(y)\n",
    "        self.z = self._as_tensor(z)\n",
    "\n",
    "    def __len__(self):\n",
    "        # dataset size is the leading dimension of X\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx], self.z[idx]\n",
    "    \n",
    "def one_hot(labels, num_classes=10, device='cpu'):\n",
    "    \"\"\"Turn integer class labels into one-hot rows.\n",
    "\n",
    "    Args:\n",
    "      labels: (tensor) class indices; cast to long before being used as an index.\n",
    "      num_classes: (int) number of classes, i.e. the width of every output row.\n",
    "      device: device the identity lookup table is moved to.\n",
    "\n",
    "    Returns:\n",
    "      (tensor) rows of the num_classes x num_classes identity selected by labels.\n",
    "    \"\"\"\n",
    "    identity = torch.eye(num_classes).to(device)\n",
    "    indices = labels.long()\n",
    "    return identity[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# functions for solving the nonconvex batchnorm problem\n",
    "class BatchNormNet(torch.nn.Module):\n",
    "    \"\"\"Two-layer network: bias-free linear -> batch normalisation -> ReLU -> bias-free linear.\n",
    "\n",
    "    The normalisation divides each centred hidden column by its Euclidean norm\n",
    "    over the batch (torch.norm, not the standard deviation) and then applies a\n",
    "    learnable scale `gamma` and shift `eta`.\n",
    "\n",
    "    Args:\n",
    "      d: input dimension.\n",
    "      num_neurons: hidden width.\n",
    "      d_out: output dimension.\n",
    "      momentum: weight given to the current batch statistics when the running\n",
    "        `mean` / `std` are refreshed (1.0 keeps only the latest batch).\n",
    "      bias: whether the shift `eta` is trainable.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d, num_neurons, d_out, momentum=1.0, bias=True):\n",
    "        super(BatchNormNet, self).__init__()\n",
    "        self.layer1 = torch.nn.Linear(d, num_neurons, bias=False)\n",
    "        # BN scale starts at one, BN shift at zero\n",
    "        self.gamma = torch.nn.Parameter(data=torch.ones(num_neurons), requires_grad=True)\n",
    "        self.eta = torch.nn.Parameter(data=torch.zeros(num_neurons), requires_grad=bias)\n",
    "\n",
    "        self.act = torch.nn.ReLU()\n",
    "        self.layer2 = torch.nn.Linear(num_neurons, d_out, bias=False)\n",
    "\n",
    "        self.momentum = momentum\n",
    "\n",
    "        # running statistics used in eval mode; never touched by the optimizer\n",
    "        self.mean = torch.nn.Parameter(data=torch.zeros(num_neurons), requires_grad=False)\n",
    "        self.std = torch.nn.Parameter(data=torch.ones(num_neurons), requires_grad=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the network output for the batch x (rows are samples).\n",
    "\n",
    "        In training mode the batch's own statistics normalise the hidden layer\n",
    "        and the running statistics are refreshed; in eval mode the stored\n",
    "        running statistics are used instead.\n",
    "        \"\"\"\n",
    "        out = self.layer1(x)\n",
    "        if self.training:\n",
    "            batch_mean = torch.mean(out, 0)\n",
    "            out = out - batch_mean\n",
    "            batch_norm = torch.norm(out, dim=0)\n",
    "\n",
    "            # Refresh the running statistics in place. Re-wrapping them in a new\n",
    "            # nn.Parameter on every step would replace the registered objects, so\n",
    "            # any reference taken earlier (e.g. optimizer param groups) went stale.\n",
    "            with torch.no_grad():\n",
    "                self.mean.copy_(self.momentum*batch_mean + (1-self.momentum)*self.mean)\n",
    "                self.std.copy_(self.momentum*batch_norm + (1-self.momentum)*self.std)\n",
    "\n",
    "            out = out/batch_norm\n",
    "        else:\n",
    "            out = (out - self.mean)/self.std\n",
    "\n",
    "        return self.layer2(self.act(out*self.gamma + self.eta))\n",
    "\n",
    "def loss_func(yhat, y, model, beta):\n",
    "    \"\"\"Squared-error loss plus weight decay on the BN and output-layer parameters.\n",
    "\n",
    "    Returns 0.5*||yhat - y||^2 + (beta/2)*||p||^2 summed over the parameters of\n",
    "    `model` named 'gamma', 'eta' and 'layer2.weight'.\n",
    "    \"\"\"\n",
    "    penalised = ('gamma', 'eta', 'layer2.weight')\n",
    "\n",
    "    total = 0.5 * torch.norm(yhat - y)**2\n",
    "\n",
    "    # standard weight decay regularization, restricted to the names above\n",
    "    for name, param in model.named_parameters():\n",
    "        if name not in penalised:\n",
    "            continue\n",
    "        total = total + beta/2 * torch.norm(param)**2\n",
    "    return total\n",
    "\n",
    "def validation(model, testloader, beta, device):\n",
    "    \"\"\"Evaluate `model` over every batch of `testloader`.\n",
    "\n",
    "    The model is switched to eval mode and left there. Each batch is cast to\n",
    "    float, moved to `device` and scored with loss_func.\n",
    "\n",
    "    Returns:\n",
    "      (test_loss, test_correct): the summed batch losses and the number of\n",
    "      correct predictions. With one target column a prediction is yhat > 0.5\n",
    "      compared to the target; otherwise the argmax over columns is compared.\n",
    "    \"\"\"\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # evaluation only: no autograd graph is needed\n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            # a single forward pass per batch (the result used to be computed twice)\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            loss = loss_func(yhat, _y, model, beta)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            if _y.shape[1] == 1:\n",
    "                test_correct += torch.eq(yhat>0.5, _y).float().sum()\n",
    "            else:\n",
    "                test_correct += torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "def batchnorm_full_gd(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                       learning_rate, verbose=False, batch_size=None, bn_momentum=1.0, \n",
    "                      print_freq=5, device='cpu', optimizer='Adam', bias=True, beta2=0.999,\n",
    "                     partial_eigenmodes=-1, return_model=True, scheduler=True):\n",
    "    \"\"\"Train a BatchNormNet with minibatch updates and record train/test metrics.\n",
    "\n",
    "    Args:\n",
    "      A_train, y_train: training inputs and targets; y_train.shape[1] sets the output width.\n",
    "      A_test, y_test: held-out data, scored in a single batch before training and after every epoch.\n",
    "      num_epochs: number of passes over the training loader.\n",
    "      num_neurons: hidden width of the network.\n",
    "      beta: weight-decay coefficient passed to loss_func.\n",
    "      learning_rate: step size handed to the optimizer.\n",
    "      verbose: forwarded to ReduceLROnPlateau.\n",
    "      batch_size: minibatch size; None means one batch holding all of A_train.\n",
    "      bn_momentum: momentum of the network's running statistics.\n",
    "      print_freq: a progress line is printed every print_freq epochs.\n",
    "      device: device string for the model and the batches.\n",
    "      optimizer: 'Adam', 'SGD' (momentum 0.9); any other value selects Adagrad.\n",
    "      bias: whether the BN shift is trainable.\n",
    "      beta2: second Adam beta.\n",
    "      partial_eigenmodes: if > 0, A_train is replaced by its reconstruction from the\n",
    "        leading partial_eigenmodes singular triplets of the centred data (A_test is untouched).\n",
    "      return_model: append the trained model, moved to CPU, to the returned tuple.\n",
    "      scheduler: if truthy, a ReduceLROnPlateau scheduler (factor 0.5, patience 20) is\n",
    "        stepped once per epoch on the last minibatch loss.\n",
    "\n",
    "    Returns:\n",
    "      (losses, accs, losses_test, accs_test, times[, model]): per-minibatch training loss,\n",
    "      per-minibatch correct count divided by batch_size, per-epoch test loss, per-epoch\n",
    "      test accuracy, and wall-clock timestamps (one more entry than losses).\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "    # D_in is input dimension, H is hidden dimension, D_out is output dimension.\n",
    "    D_in, H, D_out = A_train.shape[1], num_neurons, y_train.shape[1]\n",
    "    \n",
    "    # use gd if no batch size is passed in \n",
    "    if batch_size is None:\n",
    "        batch_size = A_train.shape[0]\n",
    "        \n",
    "    \n",
    "    # optional low-rank truncation of the (centred) training data\n",
    "    if partial_eigenmodes > 0:\n",
    "        train_means = torch.mean(A_train, dim=0)\n",
    "        u, s, v = torch.svd(A_train - train_means)\n",
    "        \n",
    "        u = u[:, :partial_eigenmodes]\n",
    "        s = s[:partial_eigenmodes]\n",
    "        v = v[:, :partial_eigenmodes]\n",
    "        \n",
    "        A_train = u @ torch.diag(s) @ v.t() + train_means\n",
    "        \n",
    "    # create the model\n",
    "    model = BatchNormNet(D_in, H, D_out, momentum=bn_momentum, bias=bias).to(device)\n",
    "    # from here on `optimizer` holds the optimizer object rather than its name\n",
    "    if optimizer=='Adam':\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, beta2))\n",
    "    elif optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "    # likewise `scheduler` changes from a flag into the scheduler object\n",
    "    if scheduler:\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \n",
    "                                                           mode='min', \n",
    "                                                           factor=0.5, \n",
    "                                                           patience=20, \n",
    "                                                           verbose=verbose)\n",
    "    \n",
    "    # arrays for saving the loss and accuracy    \n",
    "    # one slot per minibatch step for the training metrics, one per epoch (+ initial) for test\n",
    "    losses = np.zeros((int(num_epochs*np.ceil(A_train.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "\n",
    "    \n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData(X=A_train, y=y_train)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
    "    ds_test = PrepareData(X=A_test, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=A_test.shape[0], shuffle=False)\n",
    "    \n",
    "    losses_test[0], accs_test[0] = validation(model, ds_test, beta, device) # loss on the entire test set\n",
    "    \n",
    "    iter_no = 0\n",
    "    isnan = False\n",
    "    \n",
    "    for i in range(num_epochs):\n",
    "        # validation() leaves the model in eval mode, so switch back every epoch\n",
    "        model.train()\n",
    "        for ix, (_x, _y) in enumerate(ds):\n",
    "            _x = Variable(_x).float().to(device)\n",
    "            _y = Variable(_y).float().to(device)\n",
    "            \n",
    "            \n",
    "            yhat = model(_x).float()\n",
    "            \n",
    "            loss = loss_func(yhat, _y, model, beta)\n",
    "            # a NaN loss aborts training; entries not yet written stay at zero\n",
    "            if torch.isnan(loss):\n",
    "                isnan=True\n",
    "                break\n",
    "            \n",
    "            if D_out == 1:\n",
    "                correct = torch.eq(yhat>0.5, _y).float().sum()\n",
    "            else:\n",
    "                correct = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # apply the parameter update computed from the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            accs[iter_no] = correct\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        \n",
    "        if isnan:\n",
    "            break\n",
    "        \n",
    "        # the plateau scheduler watches the last minibatch loss of the epoch\n",
    "        if scheduler:\n",
    "            scheduler.step(loss.item())\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation(model, ds_test, beta, device) # loss on the entire test set\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1]/batch_size, 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "            \n",
    "    # correct counts are divided by the nominal batch_size, so a smaller final\n",
    "    # minibatch is reported with a proportionally lower accuracy\n",
    "    if return_model:\n",
    "        return losses, accs/batch_size, losses_test, accs_test/A_test.shape[0], times, model.cpu()\n",
    "    else:\n",
    "        return losses, accs/batch_size, losses_test, accs_test/A_test.shape[0], times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class custom_cvx_layer(torch.nn.Module):\n",
    "    \"\"\"Linear model of the convex formulation: gated sum of per-pattern linear maps.\n",
    "\n",
    "    For every sign pattern the layer holds two weight blocks (`v` and `w`) and\n",
    "    two bias blocks (`v_bias` and `w_bias`); the prediction uses their\n",
    "    differences, gated by the sign pattern of each sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d, num_neurons, num_classes=10, bias=True, optimizer='SGD', \n",
    "                 v_init=None, v_bias_init=None, wu=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "          d: input dimension.\n",
    "          num_neurons: number of sign patterns P.\n",
    "          num_classes: output dimension C.\n",
    "          bias: whether the bias blocks are trainable.\n",
    "          optimizer: name of the optimizer in use; the bias rescaling `wu` is\n",
    "            only applied for 'SGD'.\n",
    "          v_init: optional initial value for `v` (zeros when None).\n",
    "          v_bias_init: optional initial value for `v_bias` (zeros when None).\n",
    "          wu: divisor applied to the biases; ignored unless optimizer == 'SGD'.\n",
    "        \"\"\"\n",
    "        super(custom_cvx_layer, self).__init__()\n",
    "        \n",
    "        # P x d x C\n",
    "        if v_init is None:\n",
    "            v_init = torch.zeros(num_neurons, d, num_classes)\n",
    "        self.v = torch.nn.Parameter(data=v_init, requires_grad=True)\n",
    "        self.w = torch.nn.Parameter(data=torch.zeros(num_neurons, d, num_classes), requires_grad=True)\n",
    "        \n",
    "        # P x 1 x C\n",
    "        if v_bias_init is None:\n",
    "            v_bias_init = torch.zeros(num_neurons, 1, num_classes)\n",
    "        self.v_bias = torch.nn.Parameter(data=v_bias_init, requires_grad=bias)\n",
    "        self.w_bias = torch.nn.Parameter(data=torch.zeros(num_neurons, 1, num_classes), requires_grad=bias)\n",
    "\n",
    "        # Fall back to no rescaling when no divisor is supplied: the constructor\n",
    "        # defaults (optimizer='SGD', wu=None) used to leave self.wu = None, which\n",
    "        # made forward() fail on the division below.\n",
    "        if optimizer == 'SGD' and wu is not None:\n",
    "            self.wu = wu\n",
    "        else:\n",
    "            self.wu = 1\n",
    "        \n",
    "    def forward(self, x, sign_patterns):\n",
    "        \"\"\"Return the N x C prediction for inputs x and their N x P sign patterns.\"\"\"\n",
    "        sign_patterns = sign_patterns.unsqueeze(2) # N x P x 1\n",
    "        x = x.view(x.shape[0], -1) # n x d\n",
    "        \n",
    "        Xv_w = torch.matmul(x, self.v - self.w) + (self.v_bias - self.w_bias)/self.wu # P x N x C\n",
    "        \n",
    "        # matmul broadcasts the N x d input against the P x d x C weights, which puts\n",
    "        # the pattern axis first; swap it with the sample axis to line up with sign_patterns\n",
    "        DXv_w = torch.mul(sign_patterns, Xv_w.permute(1, 0, 2)) #  N x P x C\n",
    "        y_pred = torch.sum(DXv_w, dim=1, keepdim=False) # N x C\n",
    "        \n",
    "        return y_pred\n",
    "    \n",
    "def loss_func_cvxproblem(yhat, y, model, _x, sign_patterns, beta, rho, device):\n",
    "    \"\"\"Objective of the convex program: squared error + group norms (+ optional hinge).\n",
    "\n",
    "    Args:\n",
    "      yhat, y: predictions and targets.\n",
    "      model: object exposing v, w, v_bias, w_bias and wu.\n",
    "      _x: input batch; flattened to N x d.\n",
    "      sign_patterns: N x P gates for the batch.\n",
    "      beta: weight of the group-norm penalty.\n",
    "      rho: weight of the hinge penalty; the hinge terms are skipped unless rho > 0.\n",
    "      device: device on which the zero used by the hinge is created.\n",
    "    \"\"\"\n",
    "    flat_x = _x.view(_x.shape[0], -1)\n",
    "\n",
    "    def _group_norms(weights, biases):\n",
    "        # norm over the stacked (weights, bias/wu) axis, summed over patterns and classes\n",
    "        stacked = torch.cat((weights, biases/model.wu), dim=1)\n",
    "        return torch.sum(torch.norm(stacked, dim=1))\n",
    "\n",
    "    def _hinge(weights, biases, gates):\n",
    "        # NOTE(review): the bias enters here without the /wu rescaling that the\n",
    "        # forward pass and the group norms apply -- confirm this is intended\n",
    "        pre = torch.matmul(flat_x, torch.sum(weights, dim=2, keepdim=True)) +\\\n",
    "                        torch.sum(biases, dim=2, keepdim=True) # P x N x 1\n",
    "        pre = pre.permute(1, 0, 2) # N x P x 1\n",
    "        gated = torch.mul(gates, pre) # N x P x 1\n",
    "        return torch.sum(torch.max(-2*gated + pre, torch.Tensor([0]).to(device)))\n",
    "\n",
    "    # term 1: data fit\n",
    "    total = 0.5 * torch.norm(yhat - y)**2\n",
    "    # term 2: group-norm regularisation of both weight blocks\n",
    "    total = total + beta * _group_norms(model.v, model.v_bias)\n",
    "    total = total + beta * _group_norms(model.w, model.w_bias)\n",
    "\n",
    "    if rho > 0:\n",
    "        # term 3: hinge penalty for both weight blocks\n",
    "        gates = sign_patterns.unsqueeze(2) # N x P x 1\n",
    "        total = total + rho * _hinge(model.v, model.v_bias, gates)\n",
    "        total = total + rho * _hinge(model.w, model.w_bias, gates)\n",
    "\n",
    "    return total\n",
    "\n",
    "def validation_cvxproblem(model, testloader, u_vectors, bias_vectors, beta, rho, device):\n",
    "    \"\"\"Evaluate the convex model over every batch of `testloader`.\n",
    "\n",
    "    The gates for each batch are recomputed as (x @ u_vectors + bias_vectors >= 0).\n",
    "\n",
    "    Returns:\n",
    "      (test_loss, test_correct): the summed batch objectives and the number of\n",
    "      samples whose argmax prediction matches the argmax of the target.\n",
    "    \"\"\"\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "\n",
    "    # move the gating vectors once instead of once per batch\n",
    "    u_vectors = u_vectors.to(device)\n",
    "    bias_vectors = bias_vectors.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "            _x = _x.to(device)\n",
    "            _y = _y.to(device)\n",
    "            _x = _x.view(_x.shape[0], -1)\n",
    "            _z = (torch.matmul(_x, u_vectors) + bias_vectors >= 0)\n",
    "\n",
    "            # a single forward pass per batch (the result used to be computed twice)\n",
    "            yhat = model(_x, _z).float()\n",
    "\n",
    "            loss = loss_func_cvxproblem(yhat, _y, model, _x, _z, beta, rho, device)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "\n",
    "# solves the convex batch-norm problem; A_train and A_test are the raw data matrices, which are mean-centered and PCA-whitened inside the function\n",
    "def sgd_solver_cvxproblem(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                       learning_rate, rho, sign_patterns, u_vectors, bias_vectors=None, \n",
    "                          batch_size=None, verbose=False, \n",
    "                          device='cpu', partial_eigenmodes=-1, \n",
    "                          print_freq=5, bias=True, optimizer='Adam',\n",
    "                         return_model=True, scheduler=True):\n",
    "    \"\"\"Fit the convex model (custom_cvx_layer) on whitened data with minibatch updates.\n",
    "\n",
    "    The data are centred with the training mean and whitened through the SVD of the\n",
    "    centred training matrix: the training inputs become the left singular vectors and\n",
    "    the test inputs are mapped with v @ diag(1/s).\n",
    "\n",
    "    Args:\n",
    "      A_train, y_train: raw training inputs and targets; y_train.shape[1] sets the class count.\n",
    "      A_test, y_test: raw held-out data, scored before training and after every epoch.\n",
    "      num_epochs: number of passes over the training loader.\n",
    "      num_neurons: number of sign patterns handed to the layer.\n",
    "      beta, rho: group-norm and hinge weights of loss_func_cvxproblem.\n",
    "      learning_rate: step size handed to the optimizer.\n",
    "      sign_patterns: gates for the training rows; transposed before being paired with\n",
    "        them (presumably patterns x samples -- confirm against the caller). Recomputed\n",
    "        from the truncated data, and the argument ignored, when partial_eigenmodes > 0.\n",
    "      u_vectors: NumPy array of gating vectors in the original coordinates.\n",
    "      bias_vectors: NumPy array of gating biases; zeros are used when None or when bias is False.\n",
    "      batch_size: minibatch size for both loaders; None means all training rows at once.\n",
    "      verbose: forwarded to ReduceLROnPlateau.\n",
    "      device: device string for the model and the batches.\n",
    "      partial_eigenmodes: if > 0, keep only that many leading singular triplets.\n",
    "      print_freq: a progress line is printed every print_freq epochs.\n",
    "      bias: whether the layer's bias blocks are trainable.\n",
    "      optimizer: 'Adam', 'SGD' (momentum 0.9); any other value selects Adagrad.\n",
    "      return_model: append the trained model, moved to CPU, to the returned tuple.\n",
    "      scheduler: if truthy, a ReduceLROnPlateau scheduler (factor 0.5, patience 20) is\n",
    "        stepped once per epoch on the last minibatch loss.\n",
    "\n",
    "    Returns:\n",
    "      (losses, accs, losses_test, accs_test, times[, model]): per-minibatch training\n",
    "      objective, per-minibatch correct count divided by batch_size, per-epoch test\n",
    "      objective, per-epoch test accuracy, and wall-clock timestamps.\n",
    "    \"\"\"\n",
    "    \n",
    "    device = torch.device(device)\n",
    "    \n",
    "    print('preparing data')\n",
    "    \n",
    "    # centre both splits with the training mean\n",
    "    train_means = torch.mean(A_train, dim=0)\n",
    "    A_train = A_train - train_means\n",
    "    A_test = A_test - train_means\n",
    "    \n",
    "    # PCA whitening--u is our new data matrix\n",
    "    u, s, v = torch.svd(A_train)\n",
    "    u_vectors = torch.from_numpy(u_vectors).float()\n",
    "    \n",
    "    if bias_vectors is None or not bias:\n",
    "        bias_vectors = torch.zeros(1, u_vectors.shape[1]).float()\n",
    "    else:\n",
    "        bias_vectors = torch.from_numpy(bias_vectors).float()\n",
    "        \n",
    "    # bias divisor handed to the layer; computed from all singular values, before any truncation\n",
    "    wu = torch.mean(s**2)\n",
    "    \n",
    "    if partial_eigenmodes > 0:\n",
    "        u = u[:, :partial_eigenmodes]\n",
    "        s = s[:partial_eigenmodes]\n",
    "        v = v[:, :partial_eigenmodes]\n",
    "\n",
    "    # express the test rows and the gating vectors in the whitened coordinates\n",
    "    test_data = A_test @ v @  torch.diag(1/s)\n",
    "    train_data = u\n",
    "    u_vectors = torch.diag(s) @  v.t() @ u_vectors\n",
    "\n",
    "    # truncation changes the data, so the gates are recomputed from it\n",
    "    if partial_eigenmodes > 0:\n",
    "        sign_patterns = (train_data @ u_vectors + bias_vectors >= 0).int().t().data.numpy()\n",
    "    v_init = None\n",
    "    v_init_bias = None\n",
    "    \n",
    "\n",
    "    n, d = train_data.shape\n",
    "    # create the model\n",
    "    if batch_size is None:\n",
    "        batch_size = n\n",
    "    \n",
    "    num_classes = y_train.shape[1]\n",
    "    model = custom_cvx_layer(d, num_neurons, num_classes=num_classes, bias=bias, \n",
    "                             optimizer=optimizer, v_init=v_init, v_bias_init=v_init_bias, wu=wu).to(device)\n",
    "    \n",
    "    # from here on `optimizer` holds the optimizer object rather than its name\n",
    "    if optimizer == 'Adam':\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "    elif optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "        \n",
    "    # likewise `scheduler` changes from a flag into the scheduler object\n",
    "    if scheduler:\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \n",
    "                                                           mode='min', \n",
    "                                                           factor=0.5, \n",
    "                                                           patience=20, \n",
    "                                                           verbose=verbose)\n",
    "    \n",
    "    \n",
    "    # arrays for saving the loss and accuracy \n",
    "    losses = np.zeros((int(num_epochs*np.ceil(train_data.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    \n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    \n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    \n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData3D(X=train_data, y=y_train, z=sign_patterns.T)\n",
    "    ds= DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
    "    ds_test = PrepareData(X=test_data, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False) # note batch_size\n",
    "    \n",
    "    print('starting training')\n",
    "    losses_test[0], accs_test[0], = validation_cvxproblem(model, ds_test, u_vectors,\n",
    "                                                          bias_vectors,beta, rho, device) # loss on the entire test set\n",
    "    \n",
    "    iter_no = 0\n",
    "    times[0] = time.time()\n",
    "    isnan = False\n",
    "    \n",
    "    for i in range(num_epochs):\n",
    "        \n",
    "        for ix, (_x, _y, _z) in enumerate(ds):\n",
    "            _x = Variable(_x).float().to(device)\n",
    "            _y = Variable(_y).float().to(device)\n",
    "            _z = Variable(_z).float().to(device)\n",
    "        \n",
    "            yhat = model(_x, _z).float()\n",
    "            \n",
    "            loss = loss_func_cvxproblem(yhat, _y, model, _x, _z, beta, rho, device)\n",
    "            \n",
    "            # a NaN loss aborts training; entries not yet written stay at zero\n",
    "            if torch.isnan(loss):\n",
    "                isnan= True\n",
    "                break\n",
    "            \n",
    "            correct = torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum() # accuracy\n",
    "\n",
    "            \n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # apply the parameter update computed from the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            accs[iter_no] = correct\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        \n",
    "        if isnan:\n",
    "            break\n",
    "        \n",
    "        # the plateau scheduler watches the last minibatch loss of the epoch\n",
    "        if scheduler:\n",
    "            scheduler.step(loss.item())\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation_cvxproblem(model, ds_test, u_vectors, bias_vectors,\n",
    "                                                                 beta, rho, device) # loss on the entire test set\n",
    "        \n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], TRAIN: cvx loss: {} acc: {}. TEST: cvx loss: {} acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1]/batch_size, 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "    # correct counts are divided by the nominal batch_size, so a smaller final\n",
    "    # minibatch is reported with a proportionally lower accuracy\n",
    "    if return_model:\n",
    "        return losses, accs/batch_size, losses_test, accs_test/test_data.shape[0], times,model.to('cpu')\n",
    "    else:\n",
    "        return losses, accs/batch_size, losses_test, accs_test/test_data.shape[0], times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_conv_masks(P, image_size=32,channels=3, kernel_size=3):\n",
    "    \"\"\"Sample P random square patches and return them as 0/1 masks over flattened images.\n",
    "\n",
    "    Images are assumed to be flattened channel-major (index = c*image_size**2 +\n",
    "    row*image_size + col). Every mask selects the same kernel_size x kernel_size\n",
    "    window in all channels.\n",
    "\n",
    "    Args:\n",
    "      P: number of masks.\n",
    "      image_size: side length of the square image.\n",
    "      channels: number of image channels.\n",
    "      kernel_size: side length of the square patch.\n",
    "\n",
    "    Returns:\n",
    "      Float tensor of shape (channels*image_size**2, P); column p is the p-th mask.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if the patch does not fit inside the image.\n",
    "    \"\"\"\n",
    "    if kernel_size > image_size:\n",
    "        raise ValueError('kernel_size must not exceed image_size')\n",
    "\n",
    "    # Valid upper-left corners run from 0 to image_size - kernel_size inclusive and\n",
    "    # randint's upper bound is exclusive. The previous bound (image_size-kernel_size-1)\n",
    "    # never sampled the last two positions and failed when the patch nearly filled the image.\n",
    "    upper_left_coords = torch.randint(high=image_size-kernel_size+1, size=(P, 2))\n",
    "\n",
    "    offsets = torch.arange(kernel_size)\n",
    "    rows = upper_left_coords[:, 0].view(P, 1, 1, 1) + offsets.view(1, 1, kernel_size, 1)\n",
    "    cols = upper_left_coords[:, 1].view(P, 1, 1, 1) + offsets.view(1, 1, 1, kernel_size)\n",
    "    chans = torch.arange(channels).view(1, channels, 1, 1)\n",
    "\n",
    "    # flattened pixel index of every (channel, row, col) cell of every patch\n",
    "    flat_indices = (chans*image_size**2 + rows*image_size + cols).reshape(P, channels*kernel_size**2)\n",
    "\n",
    "    mask = torch.zeros(P, channels*image_size**2)\n",
    "    mask.scatter_(1, flat_indices, 1.0)\n",
    "    return mask.t()\n",
    "\n",
    "def generate_conv_sign_patterns(A, P, image_size=32, channels=3, kernel_size=3, verbose=False): \n",
    "    \"\"\"Sample P patch-restricted random gates and evaluate them on the centred data.\n",
    "\n",
    "    Gaussian vectors are masked to random image patches (generate_conv_masks),\n",
    "    applied to the column-centred rows of A, and shifted by Gaussian biases whose\n",
    "    scale is the standard deviation of the resulting pre-activations.\n",
    "\n",
    "    Args:\n",
    "      A: 2-D CPU tensor of flattened images, one row per sample.\n",
    "      P: number of sign patterns.\n",
    "      image_size, channels, kernel_size: patch geometry passed to generate_conv_masks.\n",
    "      verbose: if True, print the standard deviation of the pre-activations.\n",
    "\n",
    "    Returns:\n",
    "      (sign_patterns, u_vectors, bias_vectors, masks): the P x n int array of 0/1\n",
    "      gates, the d x P masked Gaussian vectors, the 1 x P biases (both NumPy), and\n",
    "      the mask tensor.\n",
    "    \"\"\"\n",
    "    d = A.shape[1]\n",
    "    A = A - torch.mean(A, 0)\n",
    "\n",
    "    umat = np.random.normal(0, 1, (d,P))\n",
    "    masks = generate_conv_masks(P, image_size, channels, kernel_size)\n",
    "    conv_umat = umat*masks.data.numpy()\n",
    "\n",
    "    # n x P pre-activations; converted explicitly rather than relying on NumPy\n",
    "    # wrapping the product back into a tensor\n",
    "    first = torch.from_numpy(np.matmul(A.numpy(), conv_umat))\n",
    "    first_std = torch.std(first)\n",
    "    # the diagnostic used to print unconditionally; it now honours `verbose`\n",
    "    if verbose:\n",
    "        print(first_std)\n",
    "    biasmat = np.random.normal(0, float(first_std), (1,P))\n",
    "\n",
    "    sampled_sign_pattern_mat = (first + torch.from_numpy(biasmat) >= 0) # n x P\n",
    "    sign_patterns = sampled_sign_pattern_mat.int().data.numpy()\n",
    "\n",
    "    return sign_patterns.T, conv_umat, biasmat, masks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data -- CIFAR-10 or CIFAR-100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cifar-10 -- using the version downloaded from \"http://www.cs.toronto.edu/~kriz/cifar.html\"\n",
    "directory = ''# CHANGE DIRECTORY HERE\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
    "\n",
    "# both splits go through the same tensor conversion + per-channel normalization\n",
    "cifar_transform = transforms.Compose([transforms.ToTensor(), normalize])\n",
    "\n",
    "# change to CIFAR-100 if using CIFAR-100\n",
    "train_dataset = datasets.CIFAR10(directory, train=True, download=True, transform=cifar_transform)\n",
    "test_dataset = datasets.CIFAR10(directory, train=False, download=True, transform=cifar_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# materialize each split as one tensor by requesting a single batch that covers it\n",
    "dummy_loader = torch.utils.data.DataLoader(\n",
    "        train_dataset, batch_size=50000, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "\n",
    "print('loading train data')\n",
    "A, y = next(iter(dummy_loader))\n",
    "\n",
    "dummy_test_loader = torch.utils.data.DataLoader(\n",
    "        test_dataset, batch_size=10000, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "\n",
    "print('loading test data')\n",
    "A_test, y_test = next(iter(dummy_test_loader))\n",
    "\n",
    "# flatten every image into one row\n",
    "A = A.view(len(A), -1)\n",
    "A_test = A_test.view(len(A_test), -1)\n",
    "\n",
    "# change to 100 classes if using CIFAR-100\n",
    "labels = one_hot(y, num_classes=10)\n",
    "labels_test = one_hot(y_test, num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SVD of the column-centered training matrix; s holds its singular values\n",
    "u, s, v = torch.svd(A - A.mean(dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cumulative share of the total squared singular values\n",
    "squared_singular_values = s.data.numpy()**2\n",
    "var_explained = np.cumsum(squared_singular_values)/np.sum(squared_singular_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# take the x-range from the data instead of hard-coding 3072, so the plot\n",
    "# stays correct for inputs with a different number of singular values\n",
    "num_components = len(var_explained)\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "plt.plot(np.arange(num_components), var_explained)\n",
    "plt.title('Cumulative Variance Explained of CIFAR-10 Dataset by Singular Value')\n",
    "# reference line at 95% of the variance\n",
    "plt.hlines( 0.95, 0, num_components, color='red', linestyles='--')\n",
    "plt.xlabel('Singular Value Index')\n",
    "plt.ylabel('Proportion of Variance Explained')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cumulative proportion of variance captured by singular values 0 through 215\n",
    "var_explained[215]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the main experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizers = ['SGD', 'Adam']\n",
    "truncation_index = 215\n",
    "batch_size = 1000\n",
    "curr_lrs = [1e-4, 5e-5, 1e-5]\n",
    "batch_sizes_other = [100, 500]\n",
    "primal_epochs = 501\n",
    "dual_epochs = 201\n",
    "dual_epochs_full = 51\n",
    "num_neurons = 4096 # to replicate CIFAR-100 results, let num_neurons be 1024\n",
    "beta = 1e-4\n",
    "bn_momentum = 0.1\n",
    "# in practice, we set the hinge loss parameter to 0 to speed up our results, \n",
    "# and it performs equally well as a small rho value\n",
    "rho = 0.0 \n",
    "problems = ['ncvx_tr', 'cvx_tr','cvx', 'ncvx']\n",
    "\n",
    "# seed both RNGs so the sampled sign patterns (NumPy and torch draws) and the\n",
    "# later weight initialization / minibatch shuffling are repeatable across runs\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "conv_sign_patterns, conv_u_vectors, conv_bias_vectors, _ = generate_conv_sign_patterns(A, num_neurons, verbose=True)\n",
    "\n",
    "results = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep every optimizer / problem / learning-rate combination and store each run's\n",
    "# metrics in `results` under a 'massrun_<optimizer>_<problem>_<lr>_<batch size>_' key.\n",
    "for optimizer in optimizers:    \n",
    "    for prob in problems:\n",
    "        # '*_tr' problems use the truncated data and the longer convex schedule\n",
    "        if prob.endswith('tr'):\n",
    "            curr_trunc = truncation_index\n",
    "            curr_dual_epochs = dual_epochs\n",
    "        else:\n",
    "            curr_trunc = -1\n",
    "            curr_dual_epochs = dual_epochs_full\n",
    "            \n",
    "        # final test accuracy reached with each learning rate\n",
    "        curr_final_accs = np.zeros(len(curr_lrs))\n",
    "        for i, lr in enumerate(curr_lrs):\n",
    "            \n",
    "            curr_name = 'massrun_' + optimizer + '_' + prob + '_' + str(lr) + '_' + str(batch_size) + '_'\n",
    "            print()\n",
    "            print(curr_name)\n",
    "            print()\n",
    "            \n",
    "            # 'ncvx*' trains the batch-norm network, everything else the convex program\n",
    "            if prob.startswith('ncvx'):\n",
    "                curr_result = batchnorm_full_gd(A, labels, A_test, labels_test, primal_epochs, \n",
    "                                                num_neurons, beta, lr, verbose=True, batch_size =batch_size,\n",
    "                                                bn_momentum=bn_momentum, device='cuda', optimizer=optimizer, \n",
    "                                                partial_eigenmodes=curr_trunc, print_freq=100, \n",
    "                                                return_model=False, scheduler=False)\n",
    "                \n",
    "            else:\n",
    "                \n",
    "                curr_result = sgd_solver_cvxproblem(A, labels, A_test, labels_test, curr_dual_epochs, \n",
    "                                                    num_neurons, beta, lr, rho, conv_sign_patterns, \n",
    "                                                    conv_u_vectors, conv_bias_vectors, batch_size=batch_size, \n",
    "                                                    verbose=True, device='cuda', partial_eigenmodes=curr_trunc,\n",
    "                                                    print_freq=50, bias=True, optimizer=optimizer, return_model=False,\n",
    "                                                    scheduler=False)\n",
    "            \n",
    "            results[curr_name] = curr_result\n",
    "            # index 3 of the result tuple is the per-epoch test accuracy\n",
    "            curr_final_accs[i] = curr_result[3][-1]\n",
    "            \n",
    "        # for the batch-norm network, rerun the learning rate with the best final test\n",
    "        # accuracy at the other batch sizes, with roughly half as many epochs\n",
    "        if prob.startswith('ncvx'):\n",
    "            lr = curr_lrs[np.argmax(curr_final_accs)]\n",
    "            \n",
    "            for bs in batch_sizes_other:\n",
    "                curr_name = 'massrun_' + optimizer + '_' + prob + '_' + str(lr) + '_' + str(bs) + '_'\n",
    "                print()\n",
    "                print(curr_name)\n",
    "                print()\n",
    "            \n",
    "                \n",
    "                curr_result = batchnorm_full_gd(A, labels, A_test, labels_test, primal_epochs//2+1, \n",
    "                                                num_neurons, beta, lr, verbose=True, batch_size =bs,\n",
    "                                                bn_momentum=bn_momentum, device='cuda', optimizer=optimizer, \n",
    "                                                partial_eigenmodes=curr_trunc, print_freq=100,\n",
    "                                                return_model=False, scheduler=False)\n",
    "                \n",
    "                results[curr_name] = curr_result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot the Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is an example of how the figures in our paper were generated: it plots the test accuracy of each method trained with SGD against wall-clock time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (key in `results`, legend label, extra line style) for every curve in the figure.\n",
    "# The learning rate and batch size are part of each key, so edit the keys if other\n",
    "# settings should be shown.\n",
    "plotted_runs = [\n",
    "    ('massrun_SGD_ncvx_1e-05_1000_', 'SGD (Baseline)', {}),\n",
    "    ('massrun_SGD_cvx_5e-05_1000_', 'Convex (Ours)', {'linestyle': 'dashed'}),\n",
    "    ('massrun_SGD_ncvx_tr_1e-05_1000_', 'SGD-Truncated (Ours)', {}),\n",
    "    ('massrun_SGD_cvx_tr_0.0001_1000_', 'Convex-Truncated (Ours)', {'linestyle': 'dashed'}),\n",
    "]\n",
    "# Stride used to subsample the time stamps so they line up with the accuracy curve\n",
    "# (presumably one time stamp per iteration and 50 iterations per epoch -- confirm).\n",
    "time_stride = 50\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "plt.xlabel('Time (sec)');  plt.grid()\n",
    "plt.ylabel('Accuracy')\n",
    "\n",
    "for run_key, run_label, line_style in plotted_runs:\n",
    "    # last entry of a result tuple: time stamps, entry 3: test accuracies\n",
    "    run_times, run_test_accs = results[run_key][-1], results[run_key][3]\n",
    "    # Subtract the start time of this very run, so every curve starts at t = 0.\n",
    "    plt.plot(run_times[::time_stride] - run_times[0], run_test_accs,\n",
    "             label=run_label, linewidth=3, **line_style)\n",
    "\n",
    "plt.legend()\n",
    "axes = plt.gca()\n",
    "axes.set_xlim([0, 800])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparison to non-BN architectures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FCNetwork(nn.Module):\n",
    "    def __init__(self, H, num_classes=10, input_dim=3072):\n",
    "        self.num_classes = num_classes\n",
    "        super(FCNetwork, self).__init__()\n",
    "        self.layer1 = nn.Sequential(nn.Linear(input_dim, H, bias=False), nn.ReLU())\n",
    "        self.layer2 = nn.Linear(H, num_classes, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.reshape(x.size(0), -1)\n",
    "        out = self.layer2(self.layer1(x))\n",
    "        return out\n",
    "\n",
    "def loss_func_primal(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    for p in model.parameters():\n",
    "        loss += beta/2 * torch.norm(p)**2\n",
    "        \n",
    "    return loss\n",
    "\n",
    "def validation_primal(model, testloader, beta, device):\n",
    "    \"\"\"Evaluate `model` on every batch produced by `testloader`.\n",
    "\n",
    "    Args:\n",
    "        model: network called as `model(x)`.\n",
    "        testloader: iterable of (inputs, targets) batches.\n",
    "        beta: regularisation weight forwarded to `loss_func_primal`.\n",
    "        device: device the batches are moved to.\n",
    "\n",
    "    Returns:\n",
    "        (test_loss, test_correct): the sum over batches of `loss_func_primal`\n",
    "        (the parameter penalty is therefore counted once per batch) and the number\n",
    "        of samples whose arg-max prediction equals the arg-max of the target.\n",
    "    \"\"\"\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "\n",
    "    # evaluation only: do not build an autograd graph for the test set\n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            # one forward pass per batch is enough\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            loss = loss_func_primal(yhat, _y, model, beta)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "# solves nonconvex problem\n",
    "def sgd_solver_pytorch_v2(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                         learning_rate, batch_size, optimizer='SGD', verbose=False, \n",
    "                        device='cuda', scheduler=False, \n",
    "                          return_model=False, print_freq=5):\n",
    "    \"\"\"Train the non-convex two-layer ReLU baseline (`FCNetwork`, no batch norm).\n",
    "\n",
    "    Args:\n",
    "        A_train, y_train: training inputs and targets (one column per class).\n",
    "        A_test, y_test: test inputs and targets; evaluated as one single batch.\n",
    "        num_epochs: number of passes over the training data.\n",
    "        num_neurons: number of hidden units.\n",
    "        beta: weight of the parameter penalty in `loss_func_primal`.\n",
    "        learning_rate: step size of the optimizer.\n",
    "        batch_size: mini-batch size of the training loader.\n",
    "        optimizer: 'Adam' or 'SGD' (momentum 0.9); any other value selects Adagrad.\n",
    "        verbose: forwarded to the learning-rate scheduler.\n",
    "        device: torch device string.\n",
    "        scheduler: if truthy, ReduceLROnPlateau (factor 0.5, patience 20) is stepped\n",
    "            once per epoch with the loss of the last mini-batch.\n",
    "        return_model: if True, the trained model (moved to the CPU) is returned as well.\n",
    "        print_freq: progress is printed every `print_freq` epochs.\n",
    "\n",
    "    Returns:\n",
    "        (losses, accs, losses_test, accs_test, times[, model]):\n",
    "        per-iteration training loss and accuracy, per-epoch test loss and accuracy\n",
    "        (index 0 is the value before training) and `time.time()` stamps (index 0 is\n",
    "        taken before training, then one per iteration). Training stops at the first\n",
    "        NaN loss; entries that were never reached stay zero.\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "    num_classes = y_train.shape[1]\n",
    "    D_in = A_train.shape[1]\n",
    "    train_len = A_train.shape[0]\n",
    "\n",
    "    # create the model: D_in inputs, num_neurons hidden units, num_classes outputs\n",
    "    model = FCNetwork(num_neurons, num_classes, D_in).to(device)\n",
    "\n",
    "    if optimizer=='Adam':\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "    elif optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "\n",
    "    # arrays for saving the loss and accuracy\n",
    "    losses = np.zeros((int(num_epochs*np.ceil(train_len / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "\n",
    "    if scheduler:\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           verbose=verbose,\n",
    "                                                           factor=0.5,\n",
    "                                                           patience=20)\n",
    "\n",
    "    ds = DataLoader(PrepareData(X=A_train, y=y_train), batch_size=batch_size, shuffle=True)\n",
    "    # the whole test set is evaluated as one batch\n",
    "    ds_test = DataLoader(PrepareData(X=A_test, y=y_test), batch_size=A_test.shape[0], shuffle=False)\n",
    "    losses_test[0], accs_test[0] = validation_primal(model, ds_test, beta, device) # loss on the entire test set\n",
    "    isnan = False\n",
    "\n",
    "    iter_no = 0\n",
    "    for i in range(num_epochs):\n",
    "        for _x, _y in ds:\n",
    "            # cast to float exactly like validation_primal does, so train and test\n",
    "            # batches reach the model with the same dtype\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            #========forward pass=====================================\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            loss = loss_func_primal(yhat, _y, model, beta)\n",
    "            if torch.isnan(loss):\n",
    "                isnan = True\n",
    "                break\n",
    "            correct = torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum()\n",
    "\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the parameters using the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            # normalise by the actual size of this batch; the last batch of an\n",
    "            # epoch can be smaller than batch_size\n",
    "            accs[iter_no] = correct.item() / _x.shape[0]\n",
    "\n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        if isnan:\n",
    "            # report the divergence instead of stopping silently\n",
    "            print('nan')\n",
    "            break\n",
    "\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation_primal(model, ds_test, beta, device) # loss on the entire test set\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "\n",
    "        if scheduler:\n",
    "            scheduler.step(losses[iter_no-1])\n",
    "\n",
    "    if return_model:\n",
    "        return losses, accs, losses_test, accs_test/A_test.shape[0], times, model.cpu()\n",
    "    else:\n",
    "        return losses, accs, losses_test, accs_test/A_test.shape[0], times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class custom_cvx_layer_no_bn(torch.nn.Module):\n",
    "    def __init__(self, d, num_neurons, num_classes=10):\n",
    "        \"\"\"\n",
    "        In the constructor we instantiate two nn.Linear modules and assign them as\n",
    "        member variables.\n",
    "        \"\"\"\n",
    "        super(custom_cvx_layer_no_bn, self).__init__()\n",
    "        \n",
    "        # P x d x C\n",
    "        self.v = torch.nn.Parameter(data=torch.zeros(num_neurons, d, num_classes), requires_grad=True)\n",
    "        self.w = torch.nn.Parameter(data=torch.zeros(num_neurons, d, num_classes), requires_grad=True)\n",
    "\n",
    "    def forward(self, x, sign_patterns):\n",
    "        sign_patterns = sign_patterns.unsqueeze(2)\n",
    "        x = x.view(x.shape[0], -1) # n x d\n",
    "        \n",
    "        Xv_w = torch.matmul(x, self.v - self.w) # P x N x C\n",
    "        \n",
    "        # for some reason, the permutation is necessary. not sure why\n",
    "        DXv_w = torch.mul(sign_patterns, Xv_w.permute(1, 0, 2)) #  N x P x C\n",
    "        y_pred = torch.sum(DXv_w, dim=1, keepdim=False) # N x C\n",
    "        \n",
    "        return y_pred\n",
    "\n",
    "def loss_func_cvxproblem_no_bn(yhat, y, model, _x, sign_patterns, beta, rho, device):\n",
    "    _x = _x.view(_x.shape[0], -1)\n",
    "    \n",
    "    # term 1\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    # term 2\n",
    "    loss = loss + beta * torch.sum(torch.norm(model.v, dim=1))\n",
    "    loss = loss + beta * torch.sum(torch.norm(model.w, dim=1))\n",
    "    \n",
    "    if rho > 0:\n",
    "    # term 3\n",
    "        sign_patterns = sign_patterns.unsqueeze(2) # N x P x 1\n",
    "\n",
    "        Xv = torch.matmul(_x, torch.sum(model.v, dim=2, keepdim=True)) # N x d times P x d x 1 -> P x N x 1\n",
    "        DXv = torch.mul(sign_patterns, Xv.permute(1, 0, 2)) # P x N x 1\n",
    "        relu_term_v = torch.max(-2*DXv + Xv.permute(1, 0, 2), torch.Tensor([0]).to(device))\n",
    "        loss = loss + rho * torch.sum(relu_term_v)\n",
    "\n",
    "        Xw = torch.matmul(_x, torch.sum(model.w, dim=2, keepdim=True))\n",
    "        DXw = torch.mul(sign_patterns, Xw.permute(1, 0, 2))\n",
    "        relu_term_w = torch.max(-2*DXw + Xw.permute(1, 0, 2), torch.Tensor([0]).to(device))\n",
    "        loss = loss + rho * torch.sum(relu_term_w)\n",
    "    \n",
    "    return loss\n",
    "\n",
    "def validation_cvxproblem_no_bn(model, testloader, u_vectors, beta, rho, device):\n",
    "    \"\"\"Evaluate the convex model on every batch produced by `testloader`.\n",
    "\n",
    "    The gate of every test sample is computed on the fly as `x @ u_vectors >= 0`.\n",
    "\n",
    "    Returns:\n",
    "        (test_loss, test_correct): the sum over batches of\n",
    "        `loss_func_cvxproblem_no_bn` (the beta term is therefore counted once per\n",
    "        batch) and the number of samples whose arg-max prediction equals the\n",
    "        arg-max of the target.\n",
    "    \"\"\"\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "    # move the gate vectors once instead of once per batch\n",
    "    u_vectors = u_vectors.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "\n",
    "            _x = _x.to(device)\n",
    "            _y = _y.to(device)\n",
    "            _x = _x.view(_x.shape[0], -1)\n",
    "            _z = (torch.matmul(_x, u_vectors) >= 0)\n",
    "\n",
    "            # one forward pass per batch is enough\n",
    "            yhat = model(_x, _z).float()\n",
    "\n",
    "            loss = loss_func_cvxproblem_no_bn(yhat, _y, model, _x, _z, beta, rho, device)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "def sgd_solver_cvxproblem_no_bn(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                       learning_rate, rho, sign_patterns, u_vectors, batch_size=1000,\n",
    "                          optimizer='SGD', verbose=False, device='cuda',\n",
    "                               print_freq=5, return_model=False, scheduler=False):\n",
    "    \"\"\"Train the convex reformulation without batch norm (`custom_cvx_layer_no_bn`).\n",
    "\n",
    "    Args:\n",
    "        A_train, y_train: training inputs and targets (one column per class).\n",
    "        A_test, y_test: test inputs and targets.\n",
    "        num_epochs: number of passes over the training data.\n",
    "        num_neurons: number of sign patterns, i.e. the leading size of v and w.\n",
    "        beta: weight of the norm penalty in `loss_func_cvxproblem_no_bn`.\n",
    "        learning_rate: step size of the optimizer.\n",
    "        rho: weight of the constraint penalty (skipped by the loss unless rho > 0).\n",
    "        sign_patterns: NOT used; the patterns are recomputed from `u_vectors` below.\n",
    "            The parameter is kept so existing calls keep working.\n",
    "        u_vectors: gate vectors (numpy array or tensor); the gates are A @ u_vectors >= 0.\n",
    "        batch_size: mini-batch size of both the training and the test loader.\n",
    "        optimizer: 'Adam' or 'SGD' (momentum 0.9); any other value selects Adagrad.\n",
    "        verbose: forwarded to the learning-rate scheduler.\n",
    "        device: torch device string.\n",
    "        print_freq: progress is printed every `print_freq` epochs.\n",
    "        return_model: if True, the trained model (moved to the CPU) is returned as well.\n",
    "        scheduler: if truthy, ReduceLROnPlateau (factor 0.5, patience 20) is stepped\n",
    "            once per epoch with the loss of the last mini-batch.\n",
    "\n",
    "    Returns:\n",
    "        (losses, accs, losses_test, accs_test, times[, model]):\n",
    "        per-iteration training loss and accuracy, per-epoch test loss and accuracy\n",
    "        (index 0 is the value before training) and `time.time()` stamps (index 0 is\n",
    "        taken right before the first epoch, then one per iteration). Training stops\n",
    "        at the first NaN loss; entries that were never reached stay zero.\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "\n",
    "    n, d = A_train.shape[0], A_train.shape[1]\n",
    "    num_classes = y_train.shape[1]\n",
    "    # create the model\n",
    "    model = custom_cvx_layer_no_bn(d, num_neurons, num_classes).to(device)\n",
    "\n",
    "    if optimizer == 'Adam':\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "    elif optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "\n",
    "    # arrays for saving the loss and accuracy\n",
    "    losses = np.zeros((int(num_epochs*np.ceil(n / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "\n",
    "    if scheduler:\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           verbose=verbose,\n",
    "                                                           factor=0.5,\n",
    "                                                           patience=20)\n",
    "\n",
    "    # accept the gate vectors as a numpy array or as a tensor\n",
    "    if not torch.is_tensor(u_vectors):\n",
    "        u_vectors = torch.from_numpy(u_vectors)\n",
    "    u_vectors = u_vectors.float()\n",
    "    # gates of the training samples; this overrides the `sign_patterns` argument\n",
    "    sign_patterns = (A_train @ u_vectors >= 0).int().t().data.numpy()\n",
    "    ds = PrepareData3D(X=A_train, y=y_train, z=sign_patterns.T)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    ds_test = PrepareData(X=A_test, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False) # note batch_size\n",
    "\n",
    "    losses_test[0], accs_test[0] = validation_cvxproblem_no_bn(model, ds_test, u_vectors, beta, rho, device) # loss on the entire test set\n",
    "    times[0] = time.time()\n",
    "    iter_no = 0\n",
    "    print('starting training')\n",
    "    isnan = False\n",
    "\n",
    "    for i in range(num_epochs):\n",
    "        model.train()\n",
    "        for _x, _y, _z in ds:\n",
    "            _x = _x.to(device)\n",
    "            _y = _y.to(device)\n",
    "            _z = _z.to(device)\n",
    "\n",
    "            #========forward pass=====================================\n",
    "            yhat = model(_x, _z).float()\n",
    "\n",
    "            loss = loss_func_cvxproblem_no_bn(yhat, _y, model, _x, _z, beta, rho, device)\n",
    "            if torch.isnan(loss):\n",
    "                isnan = True\n",
    "                break\n",
    "\n",
    "            correct = torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum() # accuracy\n",
    "\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the parameters using the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            # normalise by the actual size of this batch; the last batch of an\n",
    "            # epoch can be smaller than batch_size\n",
    "            accs[iter_no] = correct.item() / _x.shape[0]\n",
    "\n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        if isnan:\n",
    "            # report before leaving the loop (a print placed after `break` never runs)\n",
    "            print('nan')\n",
    "            break\n",
    "        model.eval()\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation_cvxproblem_no_bn(model, ds_test, u_vectors, beta, rho, device) # loss on the entire test set\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], TRAIN: cvx loss: {} acc: {}. TEST: cvx loss: {} acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "\n",
    "        if scheduler:\n",
    "            scheduler.step(losses[iter_no-1])\n",
    "\n",
    "    if return_model:\n",
    "        return losses, accs, losses_test, accs_test/A_test.shape[0], times, model.to('cpu')\n",
    "    else:\n",
    "        return losses, accs, losses_test, accs_test/A_test.shape[0], times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_baseline = {}\n",
    "curr_problems = ['cvx', 'ncvx']\n",
    "primal_learning_rates_sgd = [5e-7, 1e-6, 5e-6, 1e-5]\n",
    "dual_learning_rates_sgd = [5e-10, 1e-9, 5e-9, 1e-8]\n",
    "optimizers = ['SGD']\n",
    "dual_epochs = 51\n",
    "primal_epochs = 501\n",
    "batch_size = 1000\n",
    "rho = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every (optimizer, problem, learning rate) combination of the baseline and\n",
    "# store the solver output in `results_baseline` under a descriptive name.\n",
    "for optimizer in optimizers:\n",
    "    for prob in curr_problems:\n",
    "        # One predicate decides the epochs, the learning-rate grid and the solver,\n",
    "        # so the three can never disagree for a given problem name.\n",
    "        is_nonconvex = prob.startswith('ncvx')\n",
    "        curr_epochs = primal_epochs if is_nonconvex else dual_epochs\n",
    "        curr_lrs = primal_learning_rates_sgd if is_nonconvex else dual_learning_rates_sgd\n",
    "\n",
    "        for lr in curr_lrs:\n",
    "\n",
    "            curr_name = 'baseline_massrun_' + optimizer + '_' + prob + '_' + str(lr) + '_' + str(batch_size) + '_'\n",
    "            print()\n",
    "            print(curr_name)\n",
    "            print()\n",
    "\n",
    "            if is_nonconvex:\n",
    "                curr_result = sgd_solver_pytorch_v2(A, labels, A_test, labels_test, curr_epochs, \n",
    "                                                num_neurons, beta, lr, batch_size=batch_size, verbose=True,\n",
    "                                                device='cuda', optimizer=optimizer, \n",
    "                                                print_freq=100, return_model=False, scheduler=False)\n",
    "\n",
    "            else:\n",
    "                curr_result = sgd_solver_cvxproblem_no_bn(A, labels, A_test, labels_test, curr_epochs, \n",
    "                                    num_neurons, beta, lr, rho, conv_sign_patterns, \n",
    "                                   conv_u_vectors, batch_size=batch_size, \n",
    "                                    verbose=True, device='cuda',\n",
    "                                    print_freq=10, optimizer=optimizer, return_model=False,\n",
    "                                    scheduler=False)\n",
    "\n",
    "            results_baseline[curr_name] = curr_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
