{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# libraries \n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import time\n",
    "import scipy\n",
    "from scipy.sparse.linalg import LinearOperator\n",
    "import torch\n",
    "import sklearn.linear_model\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import geotorch\n",
    "sns.set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zca_whitening(X):\n",
    "    U,S,V = torch.svd(X)\n",
    "    epsilon= 1e-12\n",
    "    \n",
    "    ZCAMatrix =  U @ torch.eye(len(S)) @ V.t()\n",
    "    transform = V @ torch.diag(1/(S+epsilon)) @ V.t()\n",
    "    return ZCAMatrix, transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReLUNetwork(nn.Module):\n",
    "    def __init__(self, H, num_classes=10, input_dim=784):\n",
    "        self.num_classes = num_classes\n",
    "        super(ReLUNetwork, self).__init__()\n",
    "        self.layer1 = nn.Sequential(nn.Linear(input_dim, H, bias=False), nn.ReLU())\n",
    "        self.layer2 = nn.Linear(H, num_classes, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.reshape(x.size(0), -1)\n",
    "        out = self.layer2(self.layer1(x))\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearNetwork(nn.Module):\n",
    "    def __init__(self, H, num_classes=10, input_dim=784):\n",
    "        self.num_classes = num_classes\n",
    "        super(LinearNetwork, self).__init__()\n",
    "        self.layer1 = nn.Sequential(nn.Linear(input_dim, H, bias=False), nn.Identity())\n",
    "        self.layer2 = nn.Linear(H, num_classes, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.reshape(x.size(0), -1)\n",
    "        out = self.layer2(self.layer1(x))\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrepareData(Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X)\n",
    "        else:\n",
    "            self.X = X\n",
    "            \n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y)\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "    \n",
    "def one_hot_signed_asymmetric(labels, num_classes=10):\n",
    "    # value of 2 for positive class, value of -1 for negative class\n",
    "    y = torch.eye(num_classes) \n",
    "    return (3*y[labels.long()] - 1)\n",
    "\n",
    "def identity(labels, num_classes=10):\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_func_primal(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    ## l2 norm on first layer weights, l1 squared norm on second layer\n",
    "    for layer, p in enumerate(model.parameters()):\n",
    "        loss += beta/2 * torch.norm(p)**2\n",
    "    \n",
    "    return loss\n",
    "\n",
    "# solves nonconvex problem\n",
    "def sgd_solver_pytorch(ds, num_epochs, num_neurons, beta, \n",
    "                         learning_rate, batch_size, verbose=False, \n",
    "                        num_classes=10, D_in=784, train_len=60000, hot_fn=one_hot_signed_asymmetric, \n",
    "                       model='relu',test_loader=None, return_test=False, test_len=10000, return_acc=True):\n",
    "    \n",
    "    device = torch.device('cuda')\n",
    "    # D_in is input dimension, H is hidden dimension, D_out is output dimension.\n",
    "    H, D_out = num_neurons, num_classes\n",
    "    # create the model\n",
    "    if model == 'relu':\n",
    "        model = ReLUNetwork(H, D_out, D_in).to(device)\n",
    "    else:\n",
    "        model = LinearNetwork(H, D_out, D_in).to(device)\n",
    "    \n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.95)\n",
    "    \n",
    "    # arrays for saving the loss and accuracy    \n",
    "    losses = np.zeros((int(num_epochs*np.ceil(train_len / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    test_losses = np.zeros((int(num_epochs)))\n",
    "    test_accs = np.zeros(test_losses.shape)\n",
    "    \n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "    \n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           verbose=verbose,\n",
    "                                                           factor=0.5,\n",
    "                                                          eps=1e-12,\n",
    "                                                          patience=100)\n",
    "\n",
    "    iter_no = 0\n",
    "    for i in range(num_epochs):\n",
    "        for ix, (_x, _y) in enumerate(ds):\n",
    "            #=========make input differentiable=======================\n",
    "            _x = Variable(_x).to(device)\n",
    "            _y = Variable(_y).to(device)\n",
    "            \n",
    "            #========forward pass=====================================\n",
    "            yhat = model(_x).float()\n",
    "            \n",
    "            loss = loss_func_primal(yhat, hot_fn(_y, num_classes).to(device), model, beta)\n",
    "            if return_acc:\n",
    "                correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()/len(_y)\n",
    "            \n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the gradients w.r.t the loss\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            if return_acc:\n",
    "                accs[iter_no] = correct\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "            \n",
    "        if return_test:\n",
    "            with torch.no_grad():\n",
    "                for ix, (_x, _y) in enumerate(test_loader):\n",
    "                    _x = Variable(_x).to(device)\n",
    "                    _y = Variable(_y).to(device)\n",
    "                    \n",
    "                    yhat = model(_x).float()\n",
    "            \n",
    "                    loss = loss_func_primal(yhat, hot_fn(_y, num_classes).to(device), model, beta)\n",
    "                    if return_acc:\n",
    "                        correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()\n",
    "                        test_accs[i] += correct\n",
    "\n",
    "                    test_losses[i] += loss.item()\n",
    "            test_accs[i] /= test_len\n",
    "                        \n",
    "        if i % 10 == 0:\n",
    "            if return_acc and return_test:\n",
    "                print(\"Epoch [{}/{}], train loss: {}, train acc: {}, test loss: {}, test acc: {}\".format(i, num_epochs,\n",
    "                        losses[iter_no-1], accs[iter_no-1], test_losses[i], test_accs[i]))\n",
    "            elif return_test:\n",
    "                print(\"Epoch [{}/{}], train loss: {}, test loss: {}\".format(i, num_epochs,\n",
    "                        losses[iter_no-1], test_losses[i]))\n",
    "            else:\n",
    "                print(\"Epoch [{}/{}], loss: {}\".format(i, num_epochs,\n",
    "                        losses[iter_no-1]))\n",
    "        \n",
    "        scheduler.step(losses[iter_no-1])\n",
    "            \n",
    "            \n",
    "    if return_test:\n",
    "        return losses, accs, test_losses, test_accs, times, model\n",
    "    else:\n",
    "        return losses, accs, times, model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cifar 10\n",
    "directory = # CHANGE DIRECTORY HERE\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
    "train_dataset = datasets.CIFAR10(\n",
    "    directory, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))\n",
    "\n",
    "test_dataset = datasets.CIFAR10(\n",
    "    directory, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3000 sample subset of CIFAR-10\n",
    "dummy_loader= torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=10000, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A, y) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "A = A.view(A.shape[0], -1)\n",
    "\n",
    "# sample 10 instances from each class\n",
    "indices = []\n",
    "all_indices = np.arange(10000)\n",
    "for cls in range(10):\n",
    "    curr_class = all_indices[y == cls]\n",
    "    ten_sample = curr_class[:300]\n",
    "    indices.extend(list(ten_sample))\n",
    "\n",
    "np.random.shuffle(indices)\n",
    "A = A[indices]\n",
    "y = y[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mean = torch.mean(A, dim=0)\n",
    "B = A - torch.mean(A, dim=0)\n",
    "A_white, train_transform = zca_whitening(B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dummy_loader= torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=len(test_dataset), shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A_test, y_test) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "        \n",
    "A_test = A_test.view(A_test.shape[0], -1)\n",
    "print(A_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_test_white= (A_test - train_mean) @ train_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for a range of values of beta, compare the linear and relu model in both training and testing performance\n",
    "def compare_relu_linear(A_white, y, A_test, y_test, beta, num_classes= 10, hot_fn=one_hot_signed_asymmetric):\n",
    "    \n",
    "    data = PrepareData(A_white, y)\n",
    "    test_data = PrepareData(A_test, y_test)\n",
    "    batch_size = A_white.shape[0]\n",
    "    num_epochs = 400\n",
    "    num_neurons = 4000\n",
    "    learning_rate = 1e-2\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        data, batch_size=batch_size, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        test_data, batch_size=1000, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "\n",
    "    relu_ls, relu_acc, relu_test_ls, relu_test_acc, relu_times, _ = sgd_solver_pytorch(train_loader,\n",
    "                                                              num_epochs, num_neurons, beta, \n",
    "                                                                 learning_rate, batch_size, \n",
    "                                                              D_in=A_white.shape[1], verbose=True, \n",
    "                                                                     num_classes = num_classes,\n",
    "                                                                 train_len=A.shape[0], hot_fn=hot_fn, model='relu', \n",
    "                                                              test_loader=test_loader, return_test=True, return_acc=False)\n",
    "    \n",
    "    linear_ls, linear_acc, linear_test_ls, linear_test_acc, linear_times, _ = sgd_solver_pytorch(train_loader,\n",
    "                                                              num_epochs, num_neurons, beta, \n",
    "                                                                 learning_rate, batch_size, \n",
    "                                                              D_in=A_white.shape[1], verbose=True, \n",
    "                                                                     num_classes = num_classes,\n",
    "                                                                 train_len=A.shape[0], hot_fn=hot_fn, model='linear', \n",
    "                                                              test_loader=test_loader, return_test=True, return_acc=False)\n",
    "    \n",
    "    print('ReLU train loss', relu_ls[-1])\n",
    "    print('Linear train loss', linear_ls[-1])\n",
    "    \n",
    "#     print('ReLU acc', relu_acc[-1])\n",
    "#     print('Linear acc', linear_acc[-1])\n",
    "    \n",
    "    print('ReLU test loss', relu_test_ls[-1])\n",
    "    print('Linear test loss', linear_test_ls[-1])\n",
    "    \n",
    "#     print('ReLU test acc', relu_test_acc[-1])\n",
    "#     print('Linear test acc', linear_test_acc[-1])\n",
    "    \n",
    "    return relu_times, relu_ls, relu_test_ls, relu_acc, relu_test_acc, linear_times, linear_ls, linear_test_ls, linear_acc, linear_test_acc \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate labels from a planted model\n",
    "generator = ReLUNetwork(4000, 10, 3072)\n",
    "planted_labels = generator(A_white)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "planted_test_labels = generator(A_test_white)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "planted_labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_range = np.logspace(-3, 2, 6)\n",
    "relu_train_losses = np.zeros(len(beta_range))\n",
    "linear_train_losses = np.zeros(len(beta_range))\n",
    "relu_test_losses = np.zeros(len(beta_range))\n",
    "linear_test_losses = np.zeros(len(beta_range))\n",
    "\n",
    "for i, beta in enumerate(beta_range):\n",
    "    print('beta', beta)\n",
    "    results = compare_relu_linear(A_white, planted_labels, A_test_white, \n",
    "                                  planted_test_labels, beta, num_classes= 10, hot_fn=identity)\n",
    "    \n",
    "    relu_train_losses[i] = results[1][-1]\n",
    "    relu_test_losses[i] = results[2][-1]\n",
    "    linear_train_losses[i] = results[-4][-1]\n",
    "    linear_test_losses[i] = results[-3][-1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5, 3))\n",
    "\n",
    "plt.loglog(beta_range, relu_train_losses, 'x', markersize=10, label='ReLU network')\n",
    "plt.loglog(beta_range, linear_train_losses, '+', markersize=10, label='Linear network')\n",
    "\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Train Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5, 3))\n",
    "\n",
    "plt.loglog(beta_range, relu_test_losses, 'x', markersize=10, label='ReLU network')\n",
    "plt.loglog(beta_range, linear_test_losses, '+', markersize=10, label='Linear network')\n",
    "\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Test Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relu_test_losses[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relu_test_losses[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
