{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b345871",
   "metadata": {},
   "source": [
    "# Preliminaries and Function Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9316cb15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import math\n",
    "import time\n",
    "from datetime import datetime\n",
    "import pickle\n",
    "import re\n",
    "import gc\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint as cp\n",
    "from collections import OrderedDict\n",
    "from torch import Tensor\n",
    "from torch.jit.annotations import List\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb19c087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class PrepareData(torch.utils.data.Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X)\n",
    "        else:\n",
    "            self.X = X\n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y)\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "    \n",
    "class PrepareData3D(Dataset):\n",
    "    def __init__(self, X, y, z):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X)\n",
    "        else:\n",
    "            self.X = X\n",
    "            \n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y)\n",
    "        else:\n",
    "            self.y = y\n",
    "        \n",
    "        if not torch.is_tensor(z):\n",
    "            self.z = torch.from_numpy(z)\n",
    "        else:\n",
    "            self.z = z\n",
    "        \n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx], self.z[idx]\n",
    "    \n",
    "def one_hot(labels, num_classes=10, device='cpu'):\n",
    "    \"\"\"Embedding labels to one-hot form.\n",
    "\n",
    "    Args:\n",
    "      labels: (LongTensor) class labels, sized [N,].\n",
    "      num_classes: (int) number of classes.\n",
    "\n",
    "    Returns:\n",
    "      (tensor) encoded labels, sized [N, #classes].\n",
    "    \"\"\"\n",
    "    y = torch.eye(num_classes).to(device)\n",
    "    return y[labels.long()] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2abb21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# functions for solving the nonconvex batchnorm problem\n",
    "class BatchNormNet(torch.nn.Module):\n",
    "    def __init__(self, d, num_neurons, d_out, momentum=1.0, bias=True):\n",
    "        super(BatchNormNet, self).__init__()\n",
    "        self.layer1 = torch.nn.Linear(d, num_neurons, bias=False)\n",
    "        self.gamma = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=True)\n",
    "        torch.nn.init.ones_(self.gamma)\n",
    "        self.eta = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=bias)\n",
    "        torch.nn.init.zeros_(self.eta)\n",
    "        \n",
    "        self.act = torch.nn.ReLU()\n",
    "        self.layer2 = torch.nn.Linear(num_neurons, d_out, bias=False)\n",
    "        \n",
    "        self.momentum = momentum\n",
    "        \n",
    "        self.mean = torch.nn.Parameter(data=torch.zeros(num_neurons), requires_grad=False)\n",
    "        self.std = torch.nn.Parameter(data=torch.ones(num_neurons), requires_grad=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.training:\n",
    "            out = self.layer1(x)\n",
    "            next_mean = torch.mean(out, 0)\n",
    "            self.mean = torch.nn.Parameter(data =self.momentum*next_mean + (1-self.momentum)*self.mean.data, requires_grad=False)\n",
    "            \n",
    "            out = out - next_mean\n",
    "            next_std = torch.norm(out, dim=0)\n",
    "            self.std = torch.nn.Parameter(data=self.momentum*next_std + (1-self.momentum)*self.std.data, requires_grad=False)\n",
    "            \n",
    "            out = out/next_std\n",
    "            out = self.layer2(self.act(out*self.gamma + self.eta))\n",
    "        else:\n",
    "            out = self.layer1(x)\n",
    "            out = out - self.mean\n",
    "            out = out/self.std\n",
    "            out = self.layer2(self.act(out*self.gamma + self.eta))\n",
    "\n",
    "        return out\n",
    "\n",
    "def loss_func(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    # standard weight decay regularization \n",
    "    for p in model.named_parameters():\n",
    "        if p[0] in ['gamma', 'eta', 'layer2.weight']:\n",
    "            loss = loss + beta/2 * torch.norm(p[1])**2\n",
    "    return loss\n",
    "\n",
    "def validation(model, testloader, beta, device):\n",
    "    \"\"\"Evaluate model on every batch of testloader without tracking gradients.\n",
    "\n",
    "    Switches the model to eval mode before evaluating.\n",
    "\n",
    "    Returns:\n",
    "        (test_loss, test_correct): loss_func summed over batches, and the number of\n",
    "        correctly classified samples as a Python float. With a single output column a\n",
    "        sample counts as correct when (yhat > 0.5) equals the target; otherwise the\n",
    "        argmax of yhat is compared with the argmax of the target.\n",
    "    \"\"\"\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "    \n",
    "    model.eval()\n",
    "\n",
    "    # evaluation never back-propagates, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            loss = loss_func(yhat, _y, model, beta)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            if _y.shape[1] == 1:\n",
    "                test_correct += torch.eq(yhat>0.5, _y).float().sum().item()\n",
    "            else:\n",
    "                test_correct += torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum().item()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "def batchnorm_full_gd(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                       learning_rate, verbose=False, batch_size=None, bn_momentum=1.0, \n",
    "                      print_freq=5, device='cpu', optimizer='Adam', bias=True, beta2=0.999,\n",
    "                     partial_eigenmodes=-1, return_model=True, scheduler=True):\n",
    "    \"\"\"Train a BatchNormNet on (A_train, y_train) with minibatch first-order updates.\n",
    "\n",
    "    Args:\n",
    "        A_train, y_train: training features (n x d) and targets (n x C).\n",
    "        A_test, y_test: test data, evaluated as one batch before training and after\n",
    "            every epoch.\n",
    "        num_epochs: number of passes over the training data.\n",
    "        num_neurons: hidden width of the network.\n",
    "        beta: weight-decay coefficient passed to loss_func.\n",
    "        learning_rate: optimizer step size.\n",
    "        verbose: forwarded to the learning-rate scheduler.\n",
    "        batch_size: minibatch size; None means full-batch gradient descent.\n",
    "        bn_momentum: momentum of the batch-norm running statistics.\n",
    "        print_freq: print progress every print_freq epochs.\n",
    "        device: torch device string.\n",
    "        optimizer: 'Adam', 'SGD' (momentum 0.9); any other value selects Adagrad.\n",
    "        bias: whether the batch-norm shift eta is trainable.\n",
    "        beta2: Adam second-moment decay rate.\n",
    "        partial_eigenmodes: if > 0, A_train is replaced by its rank-k reconstruction\n",
    "            around the training mean (k = partial_eigenmodes); A_test is unchanged.\n",
    "        return_model: also return the trained model (moved to CPU).\n",
    "        scheduler: apply ReduceLROnPlateau to the last minibatch loss of each epoch.\n",
    "\n",
    "    Returns:\n",
    "        (losses, accs, losses_test, accs_test, times[, model]): per-iteration training\n",
    "        loss and accuracy (fraction of that minibatch classified correctly), per-epoch\n",
    "        test loss and accuracy (index 0 is before training), and wall-clock timestamps.\n",
    "        Entries after a NaN loss are left at zero.\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "    # D_in is input dimension, H is hidden dimension, D_out is output dimension.\n",
    "    D_in, H, D_out = A_train.shape[1], num_neurons, y_train.shape[1]\n",
    "    \n",
    "    # use gd if no batch size is passed in \n",
    "    if batch_size is None:\n",
    "        batch_size = A_train.shape[0]\n",
    "    \n",
    "    if partial_eigenmodes > 0:\n",
    "        # keep only the leading principal directions of the training data\n",
    "        train_means = torch.mean(A_train, dim=0)\n",
    "        u, s, v = torch.svd(A_train - train_means)\n",
    "        \n",
    "        u = u[:, :partial_eigenmodes]\n",
    "        s = s[:partial_eigenmodes]\n",
    "        v = v[:, :partial_eigenmodes]\n",
    "        \n",
    "        A_train = u @ torch.diag(s) @ v.t() + train_means\n",
    "        \n",
    "    # create the model\n",
    "    model = BatchNormNet(D_in, H, D_out, momentum=bn_momentum, bias=bias).to(device)\n",
    "    if optimizer=='Adam':\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, beta2))\n",
    "    elif optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "    if scheduler:\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \n",
    "                                                           mode='min', \n",
    "                                                           factor=0.5, \n",
    "                                                           patience=20, \n",
    "                                                           verbose=verbose)\n",
    "    \n",
    "    # arrays for saving the loss and accuracy    \n",
    "    losses = np.zeros((int(num_epochs*np.ceil(A_train.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "\n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData(X=A_train, y=y_train)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
    "    ds_test = PrepareData(X=A_test, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=A_test.shape[0], shuffle=False)\n",
    "    \n",
    "    losses_test[0], accs_test[0] = validation(model, ds_test, beta, device) # loss on the entire test set\n",
    "    \n",
    "    iter_no = 0\n",
    "    isnan = False\n",
    "    \n",
    "    for i in range(num_epochs):\n",
    "        model.train()\n",
    "        for _x, _y in ds:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "            \n",
    "            yhat = model(_x).float()\n",
    "            \n",
    "            loss = loss_func(yhat, _y, model, beta)\n",
    "            if torch.isnan(loss):\n",
    "                isnan = True\n",
    "                break\n",
    "            \n",
    "            if D_out == 1:\n",
    "                correct = torch.eq(yhat>0.5, _y).float().sum()\n",
    "            else:\n",
    "                correct = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the parameters using the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            # normalise by the actual minibatch size: the last batch of an epoch is\n",
    "            # smaller than batch_size when n is not a multiple of it\n",
    "            accs[iter_no] = correct.item() / _x.shape[0]\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        \n",
    "        if isnan:\n",
    "            break\n",
    "        \n",
    "        if scheduler:\n",
    "            scheduler.step(loss.item())\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation(model, ds_test, beta, device) # loss on the entire test set\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "            \n",
    "    if return_model:\n",
    "        return losses, accs, losses_test, accs_test/A_test.shape[0], times, model.cpu()\n",
    "    else:\n",
    "        return losses, accs, losses_test, accs_test/A_test.shape[0], times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df87cdfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "class custom_cvx_layer(torch.nn.Module):\n",
    "    def __init__(self, d, num_neurons, num_classes=10, bias=True, optimizer='SGD', \n",
    "                 v_init=None, v_bias_init=None, wu=None):\n",
    "        \"\"\"\n",
    "        In the constructor we instantiate two nn.Linear modules and assign them as\n",
    "        member variables.\n",
    "        \"\"\"\n",
    "        super(custom_cvx_layer, self).__init__()\n",
    "        \n",
    "        # P x d x C\n",
    "        if v_init is None:\n",
    "            self.v = torch.nn.Parameter(data=torch.zeros(num_neurons, d, num_classes), requires_grad=True)\n",
    "        else:\n",
    "            self.v = torch.nn.Parameter(data=v_init, requires_grad=True)\n",
    "        self.w = torch.nn.Parameter(data=torch.zeros(num_neurons, d, num_classes), requires_grad=True)\n",
    "        \n",
    "        if v_bias_init is None:\n",
    "            self.v_bias = torch.nn.Parameter(data=torch.zeros(num_neurons, 1, num_classes), requires_grad=bias)\n",
    "        else:\n",
    "            self.v_bias = torch.nn.Parameter(data=v_bias_init, requires_grad=bias)\n",
    "            \n",
    "        self.w_bias = torch.nn.Parameter(data=torch.zeros(num_neurons, 1, num_classes), requires_grad=bias)\n",
    "\n",
    "        if optimizer == 'SGD':\n",
    "            self.wu = wu\n",
    "        else:\n",
    "            self.wu = 1\n",
    "        \n",
    "    def forward(self, x, sign_patterns):\n",
    "        sign_patterns = sign_patterns.unsqueeze(2)\n",
    "        x = x.view(x.shape[0], -1) # n x d\n",
    "        \n",
    "        Xv_w = torch.matmul(x, self.v - self.w) + (self.v_bias - self.w_bias)/self.wu # P x N x C\n",
    "        \n",
    "        # for some reason, the permutation is necessary. not sure why\n",
    "        DXv_w = torch.mul(sign_patterns, Xv_w.permute(1, 0, 2)) #  N x P x C\n",
    "        y_pred = torch.sum(DXv_w, dim=1, keepdim=False) # N x C\n",
    "        \n",
    "        return y_pred\n",
    "    \n",
    "def loss_func_cvxproblem(yhat, y, model, _x, sign_patterns, beta, rho, device):\n",
    "    _x = _x.view(_x.shape[0], -1)\n",
    "    \n",
    "    # term 1\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    # term 2\n",
    "    loss = loss + beta * torch.sum(torch.norm(torch.cat((model.v, model.v_bias/model.wu), dim=1), dim=1))\n",
    "    loss = loss + beta * torch.sum(torch.norm(torch.cat((model.w, model.w_bias/model.wu), dim=1), dim=1))\n",
    "    \n",
    "    \n",
    "    if rho > 0:\n",
    "        # term 3\n",
    "        sign_patterns = sign_patterns.unsqueeze(2) # N x P x 1\n",
    "\n",
    "        Xv = torch.matmul(_x, torch.sum(model.v, dim=2, keepdim=True)) +\\\n",
    "                        torch.sum(model.v_bias, dim=2, keepdim=True)# N x d times P x d x 1 -> P x N x 1\n",
    "        DXv = torch.mul(sign_patterns, Xv.permute(1, 0, 2)) # P x N x 1\n",
    "        relu_term_v = torch.max(-2*DXv + Xv.permute(1, 0, 2), torch.Tensor([0]).to(device))\n",
    "        loss = loss + rho * torch.sum(relu_term_v)\n",
    "\n",
    "        Xw = torch.matmul(_x, torch.sum(model.w, dim=2, keepdim=True)) +\\\n",
    "                    torch.sum(model.w_bias, dim=2, keepdim=True)\n",
    "        DXw = torch.mul(sign_patterns, Xw.permute(1, 0, 2))\n",
    "        relu_term_w = torch.max(-2*DXw + Xw.permute(1, 0, 2), torch.Tensor([0]).to(device))\n",
    "        loss = loss + rho * torch.sum(relu_term_w)\n",
    "    \n",
    "    return loss\n",
    "\n",
    "def validation_cvxproblem(model, testloader, u_vectors, bias_vectors, beta, rho, device):\n",
    "    \"\"\"Evaluate the convex model on testloader without tracking gradients.\n",
    "\n",
    "    Sign patterns for each test batch are recomputed as\n",
    "    (x @ u_vectors + bias_vectors >= 0).\n",
    "\n",
    "    Returns:\n",
    "        (test_loss, test_correct): loss_func_cvxproblem summed over batches and the\n",
    "        number of samples whose argmax prediction matches the argmax target, as a\n",
    "        Python float.\n",
    "    \"\"\"\n",
    "    test_loss = 0\n",
    "    test_correct = 0\n",
    "\n",
    "    # the gating vectors are loop-invariant: move them to the device once\n",
    "    u_vectors = u_vectors.to(device)\n",
    "    bias_vectors = bias_vectors.to(device)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "            _x = _x.to(device)\n",
    "            _y = _y.to(device)\n",
    "            _x = _x.view(_x.shape[0], -1)\n",
    "            _z = (torch.matmul(_x, u_vectors) + bias_vectors >= 0)\n",
    "\n",
    "            yhat = model(_x, _z).float()\n",
    "\n",
    "            loss = loss_func_cvxproblem(yhat, _y, model, _x, _z, beta, rho, device)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum().item()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "\n",
    "# solves convex batch-norm problem, where A_train and A_test are the whitened-transformed data matrices\n",
    "def sgd_solver_cvxproblem(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta, \n",
    "                       learning_rate, rho, sign_patterns, u_vectors, bias_vectors=None, \n",
    "                          batch_size=None, verbose=False, \n",
    "                          device='cpu', partial_eigenmodes=-1, \n",
    "                          print_freq=5, bias=True, optimizer='Adam',\n",
    "                         return_model=True, scheduler=True):\n",
    "    \"\"\"Train the convex reformulation of a batch-norm ReLU network.\n",
    "\n",
    "    Both data sets are centred with the training mean and PCA-whitened: the training\n",
    "    matrix becomes U from A_train = U S V^T, the test data become A_test V S^-1, and the\n",
    "    gating vectors are mapped to S V^T u_vectors, so U times the mapped vectors equals\n",
    "    (the rank-k approximation of) the centred training data times u_vectors.\n",
    "\n",
    "    Args:\n",
    "        A_train, y_train, A_test, y_test: raw (un-whitened) data and targets.\n",
    "        num_epochs: passes over the training data.\n",
    "        num_neurons: number of sign patterns P.\n",
    "        beta: group-lasso coefficient.\n",
    "        learning_rate: optimizer step size.\n",
    "        rho: hinge-penalty coefficient.\n",
    "        sign_patterns: P x n 0/1 array for the training rows (recomputed from the\n",
    "            truncated data when partial_eigenmodes > 0).\n",
    "        u_vectors: d x P NumPy array of gating directions.\n",
    "        bias_vectors: 1 x P NumPy array of gating offsets; replaced by zeros when None\n",
    "            or when bias is False.\n",
    "        batch_size: minibatch size (None = full batch); also used for test batches.\n",
    "        verbose: forwarded to the learning-rate scheduler.\n",
    "        device: torch device string.\n",
    "        partial_eigenmodes: if > 0, keep only that many principal directions.\n",
    "        print_freq: print progress every print_freq epochs.\n",
    "        bias: whether the model's bias parameters are trainable.\n",
    "        optimizer: 'Adam', 'SGD' (momentum 0.9); any other value selects Adagrad.\n",
    "        return_model: also return the trained model (moved to CPU).\n",
    "        scheduler: apply ReduceLROnPlateau to the last minibatch loss of each epoch.\n",
    "\n",
    "    Returns:\n",
    "        (losses, accs, losses_test, accs_test, times[, model]): per-iteration training\n",
    "        loss and accuracy (fraction of that minibatch classified correctly), per-epoch\n",
    "        test loss and accuracy (index 0 is before training), and wall-clock timestamps.\n",
    "        Entries after a NaN loss are left at zero.\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "    \n",
    "    print('preparing data')\n",
    "    \n",
    "    train_means = torch.mean(A_train, dim=0)\n",
    "    A_train = A_train - train_means\n",
    "    A_test = A_test - train_means\n",
    "    \n",
    "    # PCA whitening--u is our new data matrix\n",
    "    u, s, v = torch.svd(A_train)\n",
    "    u_vectors = torch.from_numpy(u_vectors).float()\n",
    "    \n",
    "    if bias_vectors is None or not bias:\n",
    "        bias_vectors = torch.zeros(1, u_vectors.shape[1]).float()\n",
    "    else:\n",
    "        bias_vectors = torch.from_numpy(bias_vectors).float()\n",
    "    \n",
    "    # bias scale for the SGD parameterisation, taken from the full spectrum\n",
    "    wu = torch.mean(s**2)\n",
    "    \n",
    "    if partial_eigenmodes > 0:\n",
    "        u = u[:, :partial_eigenmodes]\n",
    "        s = s[:partial_eigenmodes]\n",
    "        v = v[:, :partial_eigenmodes]\n",
    "\n",
    "    # NOTE(review): 1/s blows up if A_train has (near-)zero singular values\n",
    "    test_data = A_test @ v @  torch.diag(1/s)\n",
    "    train_data = u\n",
    "    u_vectors = torch.diag(s) @  v.t() @ u_vectors\n",
    "\n",
    "    if partial_eigenmodes > 0:\n",
    "        # the supplied patterns describe the untruncated data; recompute them\n",
    "        sign_patterns = (train_data @ u_vectors + bias_vectors >= 0).int().t().data.numpy()\n",
    "    v_init = None\n",
    "    v_init_bias = None\n",
    "\n",
    "    n, d = train_data.shape\n",
    "    # create the model\n",
    "    if batch_size is None:\n",
    "        batch_size = n\n",
    "    \n",
    "    num_classes = y_train.shape[1]\n",
    "    model = custom_cvx_layer(d, num_neurons, num_classes=num_classes, bias=bias, \n",
    "                             optimizer=optimizer, v_init=v_init, v_bias_init=v_init_bias, wu=wu).to(device)\n",
    "    \n",
    "    if optimizer == 'Adam':\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "    elif optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "        \n",
    "    if scheduler:\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, \n",
    "                                                           mode='min', \n",
    "                                                           factor=0.5, \n",
    "                                                           patience=20, \n",
    "                                                           verbose=verbose)\n",
    "    \n",
    "    # arrays for saving the loss and accuracy \n",
    "    losses = np.zeros((int(num_epochs*np.ceil(train_data.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    \n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    \n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    \n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData3D(X=train_data, y=y_train, z=sign_patterns.T)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
    "    ds_test = PrepareData(X=test_data, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False) # note batch_size\n",
    "    \n",
    "    print('starting training')\n",
    "    losses_test[0], accs_test[0] = validation_cvxproblem(model, ds_test, u_vectors,\n",
    "                                                         bias_vectors, beta, rho, device) # loss on the entire test set\n",
    "    \n",
    "    iter_no = 0\n",
    "    times[0] = time.time()\n",
    "    isnan = False\n",
    "    \n",
    "    for i in range(num_epochs):\n",
    "        \n",
    "        for _x, _y, _z in ds:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "            _z = _z.float().to(device)\n",
    "        \n",
    "            yhat = model(_x, _z).float()\n",
    "            \n",
    "            loss = loss_func_cvxproblem(yhat, _y, model, _x, _z, beta, rho, device)\n",
    "            \n",
    "            if torch.isnan(loss):\n",
    "                isnan = True\n",
    "                break\n",
    "            \n",
    "            correct = torch.eq(torch.argmax(yhat, dim=1), torch.argmax(_y, dim=1)).float().sum() # accuracy\n",
    "            \n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the parameters using the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            # normalise by the actual minibatch size: the last batch of an epoch is\n",
    "            # smaller than batch_size when n is not a multiple of it\n",
    "            accs[iter_no] = correct.item() / _x.shape[0]\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "        \n",
    "        if isnan:\n",
    "            break\n",
    "        \n",
    "        if scheduler:\n",
    "            scheduler.step(loss.item())\n",
    "        # get test loss and accuracy\n",
    "        losses_test[i+1], accs_test[i+1] = validation_cvxproblem(model, ds_test, u_vectors, bias_vectors,\n",
    "                                                                 beta, rho, device) # loss on the entire test set\n",
    "        \n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], TRAIN: cvx loss: {} acc: {}. TEST: cvx loss: {} acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3), \n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/test_data.shape[0], 3)))\n",
    "    if return_model:\n",
    "        return losses, accs, losses_test, accs_test/test_data.shape[0], times, model.to('cpu')\n",
    "    else:\n",
    "        return losses, accs, losses_test, accs_test/test_data.shape[0], times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db8954af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_sign_patterns(A, P,  verbose=False): \n",
    "    # generate sign patterns\n",
    "    n, d = A.shape\n",
    "    unique_sign_pattern_list = []  # sign patterns\n",
    "    u_vector_list = []             # random vectors used to generate the sign paterns\n",
    "    A = A - torch.mean(A, 0)\n",
    "    \n",
    "    umat = np.random.normal(0, 1, (d,P))\n",
    "    \n",
    "    # d x d\n",
    "    first = np.matmul(A, umat)\n",
    "    biasmat = np.random.normal(0, torch.std(first), (1,P))\n",
    "    \n",
    "    sampled_sign_pattern_mat = (first +torch.from_numpy(biasmat) >= 0) # n x P\n",
    "        \n",
    "    sign_patterns = sampled_sign_pattern_mat.int().data.numpy()\n",
    "    u_vectors = umat\n",
    "    bias_vectors = biasmat\n",
    "        \n",
    "    return sign_patterns.T, u_vectors, bias_vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31ec99d8",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6ccc476",
   "metadata": {},
   "outputs": [],
   "source": [
    "# folder containing CNAE-9.data ('' = current working directory)\n",
    "directory = ''  # CHANGE DIRECTORY HERE\n",
    "data_file = os.path.join(directory, 'CNAE-9.data')\n",
    "data = np.loadtxt(data_file, delimiter=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2448dc9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# column 0 is the class label, the remaining columns are the features.\n",
    "# Guarded so that re-running this cell does not strip another column off `data`.\n",
    "if isinstance(data, np.ndarray):\n",
    "    lab = torch.from_numpy(data[:, 0]).float()\n",
    "    data = torch.from_numpy(data[:, 1:]).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6438716",
   "metadata": {},
   "outputs": [],
   "source": [
    "# centre every feature column on its mean over all rows\n",
    "# NOTE(review): this mean includes the rows later used as the test set\n",
    "centered_data = data - data.mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eae858a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global RNGs so the shuffles below and every later np.random / torch\n",
    "# draw (sign patterns, weight initialisation, minibatch order) are reproducible.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# fixed split: the first n_train rows are used for training, the rest for testing\n",
    "n_train = 900\n",
    "train_indices = np.arange(n_train)\n",
    "test_indices = np.arange(n_train, len(data))\n",
    "\n",
    "np.random.shuffle(train_indices)\n",
    "np.random.shuffle(test_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad52f1de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb127fde",
   "metadata": {},
   "outputs": [],
   "source": [
    "A, A_test = centered_data[train_indices], centered_data[test_indices]\n",
    "# labels start at 1; shift to 0-based class indices before one-hot encoding\n",
    "labels = one_hot(lab[train_indices] - 1, num_classes=9)\n",
    "labels_test = one_hot(lab[test_indices] - 1, num_classes=9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa3ba0ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pretty_size(size):\n",
    "    \"\"\"Pretty prints a torch.Size object\"\"\"\n",
    "    assert(isinstance(size, torch.Size))\n",
    "    return \" × \".join(map(str, size))\n",
    "\n",
    "def dump_tensors(gpu_only=True):\n",
    "    \"\"\"Prints a list of the Tensors being tracked by the garbage collector.\"\"\"\n",
    "    total_size = 0\n",
    "    for obj in gc.get_objects():\n",
    "        try:\n",
    "            if torch.is_tensor(obj):\n",
    "                if not gpu_only or obj.is_cuda:\n",
    "                    print(\"%s:%s%s %s\" % (type(obj).__name__, \n",
    "                                          \" GPU\" if obj.is_cuda else \"\",\n",
    "                                          \" pinned\" if obj.is_pinned else \"\",\n",
    "                                          pretty_size(obj.size())))\n",
    "                    total_size += obj.numel()\n",
    "            elif hasattr(obj, \"data\") and torch.is_tensor(obj.data):\n",
    "                if not gpu_only or obj.is_cuda:\n",
    "                    print(\"%s → %s:%s%s%s%s %s\" % (type(obj).__name__, \n",
    "                                                   type(obj.data).__name__, \n",
    "                                                   \" GPU\" if obj.is_cuda else \"\",\n",
    "                                                   \" pinned\" if obj.data.is_pinned else \"\",\n",
    "                                                   \" grad\" if obj.requires_grad else \"\", \n",
    "                                                   \" volatile\" if obj.volatile else \"\",\n",
    "                                                   pretty_size(obj.data.size())))                    \n",
    "                    \n",
    "                    total_size += obj.data.numel()\n",
    "                    del obj\n",
    "            del obj\n",
    "        except Exception as e:\n",
    "            pass        \n",
    "    print(\"Total size:\", total_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72d132eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SVD of the training matrix after re-centring it on its own column means\n",
    "u, s, v = torch.svd(A - A.mean(dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "051f3f4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cumulative fraction of variance captured by the leading singular values\n",
    "var_explained = np.cumsum((s.data.numpy())**2)/np.sum(s.data.numpy()**2)\n",
    "# one point per singular value (min(n, d) of them), not per feature\n",
    "n_singular = len(var_explained)\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "ax.plot(np.arange(n_singular), var_explained)\n",
    "ax.set_title('Cumulative Variance Explained of CNAE-9 Dataset by Singular Value')\n",
    "ax.hlines(0.95, 0, n_singular, color='red', linestyles='--')\n",
    "ax.set_xlabel('Singular Value Index')\n",
    "ax.set_ylabel('Proportion of Variance Explained')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12ee7fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fraction of variance explained by the top 201 singular values (0-based index 200)\n",
    "var_explained[200]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be0b5d11",
   "metadata": {},
   "source": [
    "# Main Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7cafa55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- optimisation sweep ---\n",
    "optimizers = ['SGD']          # optimizers to sweep over\n",
    "curr_lrs = [1e-5]             # learning rates to sweep over\n",
    "batch_size = 900              # equal to the number of training rows, i.e. full batch\n",
    "problems = ['ncvx']           # 'ncvx*' trains BatchNormNet, anything else the convex solver; '*tr' = truncated\n",
    "\n",
    "# --- epochs ---\n",
    "primal_epochs = 851           # nonconvex (primal) training\n",
    "dual_epochs = 251             # convex problem on truncated data\n",
    "dual_epochs_full = 101        # convex problem on the full data\n",
    "\n",
    "# --- model / regularisation ---\n",
    "num_neurons = 10000           # hidden width = number of sampled sign patterns\n",
    "truncation_index = 100        # principal directions kept for '*tr' problems\n",
    "beta = 1e-2\n",
    "bn_momentum = 0.1\n",
    "# in practice, we set the hinge loss parameter to 0 to speed up our results, \n",
    "# and it performs equally well as a small rho value\n",
    "rho = 0.0\n",
    "\n",
    "sign_patterns, u_vectors, bias_vectors = generate_sign_patterns(A, num_neurons, verbose=True)\n",
    "\n",
    "results = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f780108c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run every (optimizer, problem, learning rate) combination; results are keyed by run name\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU without a GPU\n",
    "\n",
    "for optimizer in optimizers:\n",
    "    for prob in problems:\n",
    "        if prob.endswith('tr'):\n",
    "            curr_trunc = truncation_index\n",
    "            curr_dual_epochs = dual_epochs\n",
    "        else:\n",
    "            curr_trunc = -1\n",
    "            curr_dual_epochs = dual_epochs_full\n",
    "            \n",
    "        curr_final_accs = np.zeros(len(curr_lrs))\n",
    "        for i, lr in enumerate(curr_lrs):\n",
    "            \n",
    "            curr_name = 'massrun_' + optimizer + '_' + prob + '_' + str(lr) + '_' + str(batch_size) + '_'\n",
    "            print()\n",
    "            print(curr_name)\n",
    "            print()\n",
    "            \n",
    "            if prob.startswith('ncvx'):\n",
    "                curr_result = batchnorm_full_gd(A, labels, A_test, labels_test, primal_epochs, \n",
    "                                                num_neurons, beta, lr, verbose=True, batch_size=batch_size,\n",
    "                                                bn_momentum=bn_momentum, device=device, optimizer=optimizer, \n",
    "                                                partial_eigenmodes=curr_trunc, print_freq=100, \n",
    "                                                return_model=False, scheduler=False)\n",
    "            else:\n",
    "                curr_result = sgd_solver_cvxproblem(A, labels, A_test, labels_test, curr_dual_epochs, \n",
    "                                                    num_neurons, beta, lr, rho, sign_patterns, \n",
    "                                                    u_vectors, bias_vectors, batch_size=batch_size, \n",
    "                                                    verbose=True, device=device, partial_eigenmodes=curr_trunc,\n",
    "                                                    print_freq=50, bias=True, optimizer=optimizer, return_model=False,\n",
    "                                                    scheduler=False)\n",
    "            \n",
    "            results[curr_name] = curr_result\n",
    "            # index 3 of the result tuple holds the per-epoch test accuracies\n",
    "            curr_final_accs[i] = curr_result[3][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "026af361",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9932a743",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef2b2b3e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61423972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenmode truncation levels to sweep (passed as partial_eigenmodes). A.shape[1], the\n",
    "# number of columns of A, is the run plotted as plain 'Convex' further down.\n",
    "truncation_indices = [A.shape[1]] + [10, 100, 200, 400]\n",
    "results_trunc = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "073793f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the convex problem once per truncation level; results are keyed by run name.\n",
    "for truncation_index in truncation_indices:\n",
    "    for optimizer in ['SGD']:\n",
    "        for prob in ['cvx_tr']:\n",
    "            # Problems ending in 'tr' keep `truncation_index` eigenmodes and use dual_epochs;\n",
    "            # any other problem uses -1 for partial_eigenmodes together with dual_epochs_full.\n",
    "            is_truncated = prob.endswith('tr')\n",
    "            curr_trunc = truncation_index if is_truncated else -1\n",
    "            curr_dual_epochs = dual_epochs if is_truncated else dual_epochs_full\n",
    "\n",
    "            for i, lr in enumerate([1e-5]):\n",
    "                curr_name = f'massrun_{optimizer}_{prob}_{lr}_{batch_size}_{curr_trunc}'\n",
    "                print(f'\\n{curr_name}\\n')\n",
    "\n",
    "                curr_result = sgd_solver_cvxproblem(\n",
    "                    A, labels, A_test, labels_test, curr_dual_epochs,\n",
    "                    num_neurons, beta, lr, rho, sign_patterns,\n",
    "                    u_vectors, bias_vectors, batch_size=batch_size,\n",
    "                    verbose=True, device='cuda', partial_eigenmodes=curr_trunc,\n",
    "                    print_freq=50, bias=True, optimizer=optimizer,\n",
    "                    return_model=False, scheduler=False,\n",
    "                )\n",
    "                results_trunc[curr_name] = curr_result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b487ec21",
   "metadata": {},
   "source": [
    "# Plot the Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5963a6a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 6))\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "plt.xlabel('Time (sec)');  plt.grid()\n",
    "plt.ylabel(\"Accuracy\")\n",
    "\n",
    "# Run names follow the format used when the results were stored:\n",
    "#   'massrun_<optimizer>_<prob>_<lr>_<batch_size>_<truncation>'  (str(1e-5) == '1e-05';\n",
    "#   the non-convex baseline has an empty truncation field). Building them from batch_size\n",
    "#   and A.shape[1] instead of hard-coding 900 / 856 keeps the keys in sync with those runs.\n",
    "run_prefix = 'massrun_SGD_'\n",
    "run_suffix = '_' + str(1e-5) + '_' + str(batch_size) + '_'\n",
    "\n",
    "def _plot_accuracy_vs_time(result, label, n_points=None, **line_kwargs):\n",
    "    \"\"\"Plot result[3] (accuracy) against result[-1] (time stamps) shifted to start at 0.\n",
    "\n",
    "    Only the first `n_points` entries are drawn (None draws the whole run); any extra\n",
    "    keyword arguments (e.g. linestyle) are forwarded to plt.plot.\n",
    "    \"\"\"\n",
    "    times = result[-1][:n_points] - result[-1][0]\n",
    "    plt.plot(times, result[3][:n_points], label=label, linewidth=3, **line_kwargs)\n",
    "\n",
    "# Non-convex SGD baseline: full run, default line style.\n",
    "_plot_accuracy_vs_time(results[run_prefix + 'ncvx' + run_suffix], 'SGD (Baseline)')\n",
    "\n",
    "# Convex runs as (truncation level, legend label, points kept); A.shape[1] is the untruncated\n",
    "# 'Convex' run. NOTE(review): the fixed cut-offs 176/139/93 have no recorded rationale --\n",
    "# presumably they trim each curve near the 25 s x-limit; confirm before reusing.\n",
    "convex_runs = [\n",
    "    (10, 'Convex-Truncated, 10', None),\n",
    "    (100, 'Convex-Truncated, 100', None),\n",
    "    (200, 'Convex-Truncated, 200', 176),\n",
    "    (400, 'Convex-Truncated, 400', 139),\n",
    "    (A.shape[1], 'Convex', 93),\n",
    "]\n",
    "for trunc, label, n_points in convex_runs:\n",
    "    key = run_prefix + 'cvx_tr' + run_suffix + str(trunc)\n",
    "    _plot_accuracy_vs_time(results_trunc[key], label, n_points, linestyle='dashed')\n",
    "\n",
    "plt.legend()\n",
    "axes = plt.gca()\n",
    "axes.set_ylim([0.65, 1.0])\n",
    "axes.set_xlim([0, 25])\n",
    "plt.savefig(\"cnae9_testacc.eps\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ab573e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "751e60ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a03b3390",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
