{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# libraries \n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import time\n",
    "import scipy\n",
    "from scipy.sparse.linalg import LinearOperator\n",
    "import torch\n",
    "import sklearn.linear_model\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import geotorch\n",
    "sns.set()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 1: Whitened Data Matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zca_whitening(X):\n",
    "    U,S,V = torch.svd(X)\n",
    "    epsilon= 1e-12\n",
    "    \n",
    "    ZCAMatrix =  U @ torch.eye(len(S)) @ V.t()\n",
    "    transform = V @ torch.diag(1/(S+epsilon)) @ V.t()\n",
    "    return ZCAMatrix, transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class convex_relax(nn.Module):\n",
    "    def __init__(self, n):\n",
    "        super(convex_relax, self).__init__()\n",
    "        self.U = nn.Parameter(data=torch.randn((n, n))/(n**2), requires_grad=True)\n",
    "        self.eye= nn.Parameter(data=torch.eye(n), requires_grad=False)\n",
    "        self.n = n\n",
    "    \n",
    "    def forward(self, y):\n",
    "        return torch.trace(y.t() @ torch.inverse(self.eye + 2 * (self.U)) @ y)\n",
    "\n",
    "\n",
    "class constrained_net(nn.Module):\n",
    "    def __init__(self, n):\n",
    "        super(constrained_net, self).__init__()\n",
    "        self.cvx = convex_relax(n)\n",
    "        geotorch.positive_semidefinite(self.cvx, \"U\")\n",
    "\n",
    "    def forward(self, y):\n",
    "        return self.cvx( y)\n",
    "\n",
    "def train_constrained(A_white, y_hot, num_epochs, beta, rho, lr):\n",
    "    model =  constrained_net(len(A_white)).cuda()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           verbose=True,\n",
    "                                                           factor=0.5,\n",
    "                                                          eps=1e-12,\n",
    "                                                          patience=500)\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        negativity = torch.sum(torch.max(-model.cvx.U, torch.Tensor([0]).cuda()))\n",
    "        \n",
    "        loss = 0.5*model(y_hot.cuda()) + beta**2 *torch.norm(model.cvx.U, 'nuc') + rho*negativity\n",
    "        if epoch % 500 == 0:\n",
    "            print(epoch, loss.item(), negativity.item())\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step(loss)\n",
    "    \n",
    "    # final projection\n",
    "    model.cvx.U[model.cvx.U< 0] = 0\n",
    "    loss_final = 0.5*model(y_hot.cuda()) + beta**2 * (torch.norm(model.cvx.U, 'nuc'))\n",
    "    print('final loss', loss_final.item())\n",
    "        \n",
    "    return model, loss_final.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solve the convex NN program corresponding to convex relaxation in data whitened case. \n",
    "def train_constrained_cvxpy(A_white, y_hot, beta, eps=1e-4):\n",
    "    n = len(A_white)\n",
    "    d = A_white.shape[1]\n",
    "    y_hot = y_hot.data.numpy()\n",
    "    U = cp.Variable((n, n), PSD=True)\n",
    "#     U_p = cp.Variable((n, n), PSD=True)\n",
    "    objective = cp.Minimize( 0.5*cp.matrix_frac(y_hot, torch.eye(n) + 2*U) +\\\n",
    "                            beta**2 * (cp.trace(U)))\n",
    "    constraints = [U >= 0]\n",
    "    problem = cp.Problem(objective, constraints)\n",
    "    problem.solve(solver='SCS', eps=eps, max_iters=30000, verbose=True)\n",
    "    print('copositive relaxation loss', objective.value)\n",
    "    return objective.value, U.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FCNetwork(nn.Module):\n",
    "    def __init__(self, H, num_classes=10, input_dim=784):\n",
    "        self.num_classes = num_classes\n",
    "        super(FCNetwork, self).__init__()\n",
    "        self.layer1 = nn.Sequential(nn.Linear(input_dim, H, bias=False), nn.ReLU())\n",
    "        self.layer2 = nn.Linear(H, num_classes, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.reshape(x.size(0), -1)\n",
    "        out = self.layer2(self.layer1(x))\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrepareData(Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X)\n",
    "        else:\n",
    "            self.X = X\n",
    "            \n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y)\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "    \n",
    "def one_hot_signed(labels, num_classes=10):\n",
    "    y = torch.eye(num_classes) \n",
    "    return (2*y[labels.long()] - 1)\n",
    "\n",
    "def one_hot(labels, num_classes=10):\n",
    "    y = torch.eye(num_classes) \n",
    "    return y[labels.long()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_func_primal(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    ## l2 norm on first layer weights, l1 squared norm on second layer\n",
    "    for layer, p in enumerate(model.parameters()):\n",
    "        loss += beta/2 * torch.norm(p)**2\n",
    "    \n",
    "    return loss\n",
    "\n",
    "# solves nonconvex problem\n",
    "def sgd_solver_pytorch(ds, num_epochs, num_neurons, beta, \n",
    "                         learning_rate, batch_size, verbose=False, \n",
    "                        num_classes=10, D_in=784, train_len=60000, hot_fn=one_hot,\n",
    "                      test_loader=None, return_test=False, test_len=10000):\n",
    "    \n",
    "    device = torch.device('cuda')\n",
    "    # D_in is input dimension, H is hidden dimension, D_out is output dimension.\n",
    "    H, D_out = num_neurons, num_classes\n",
    "    # create the model\n",
    "    model = FCNetwork(H, D_out, D_in).to(device)\n",
    "    \n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.95)\n",
    "    \n",
    "    # arrays for saving the loss and accuracy    \n",
    "    losses = np.zeros((int(num_epochs*np.ceil(train_len / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    test_losses = np.zeros((int(num_epochs)))\n",
    "    test_accs = np.zeros(test_losses.shape)\n",
    "    \n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "    \n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           verbose=verbose,\n",
    "                                                           factor=0.5,\n",
    "                                                          eps=1e-12,\n",
    "                                                          patience=100)\n",
    "\n",
    "    iter_no = 0\n",
    "    for i in range(num_epochs):\n",
    "        for ix, (_x, _y) in enumerate(ds):\n",
    "            #=========make input differentiable=======================\n",
    "            _x = Variable(_x).to(device)\n",
    "            _y = Variable(_y).to(device)\n",
    "            \n",
    "            #========forward pass=====================================\n",
    "            yhat = model(_x).float()\n",
    "            \n",
    "            loss = loss_func_primal(yhat, hot_fn(_y, num_classes).to(device), model, beta)\n",
    "            if return_test:\n",
    "                correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()/len(_y)\n",
    "            \n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the gradients w.r.t the loss\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            if return_test:\n",
    "                accs[iter_no] = correct\n",
    "        \n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "            \n",
    "        if return_test:\n",
    "            with torch.no_grad():\n",
    "                for ix, (_x, _y) in enumerate(test_loader):\n",
    "                    _x = Variable(_x).to(device)\n",
    "                    _y = Variable(_y).to(device)\n",
    "                    \n",
    "                    yhat = model(_x).float()\n",
    "            \n",
    "                    loss = loss_func_primal(yhat, hot_fn(_y, num_classes).to(device), model, beta)\n",
    "                    if return_test:\n",
    "                        correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()\n",
    "                        test_accs[i] += correct\n",
    "\n",
    "                    test_losses[i] += loss.item()\n",
    "            test_accs[i] /= test_len\n",
    "                        \n",
    "        if i % 10 == 0:\n",
    "            if return_test:\n",
    "                print(\"Epoch [{}/{}], train loss: {}, train acc: {}, test loss: {}, test acc: {}\".format(i, num_epochs,\n",
    "                        losses[iter_no-1], accs[iter_no-1], test_losses[i], test_accs[i]))\n",
    "            else:\n",
    "                print(\"Epoch [{}/{}], loss: {}\".format(i, num_epochs,\n",
    "                        losses[iter_no-1]))\n",
    "        \n",
    "        scheduler.step(losses[iter_no-1])\n",
    "            \n",
    "            \n",
    "    if return_test:\n",
    "        return losses, accs, test_losses, test_accs, times, model\n",
    "    else:\n",
    "        return losses, accs, times, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_whitened(A_white, y, beta, num_classes= 10, hot_encode = True):\n",
    "    if hot_encode:\n",
    "        hot_fn = one_hot\n",
    "    else:\n",
    "        hot_fn = one_hot_signed\n",
    "        \n",
    "    y_hot = hot_fn(y, num_classes).double()\n",
    "        \n",
    "    losses = np.zeros(3)\n",
    "    U, S, V = torch.svd(y_hot)\n",
    "    U = U.float()\n",
    "    S = S.float()\n",
    "    V = V.float()\n",
    "    eps = 1e-13\n",
    "\n",
    "    U_plus = torch.max(U, torch.tensor([0.]))\n",
    "    U_plus_norms = torch.norm(U_plus, dim=0)\n",
    "    U_resigned = U * (-2*(U_plus_norms < 0.5**(0.5)).float()+1)\n",
    "    V_resigned = V * (-2*(U_plus_norms < 0.5**(0.5)).float()+1)\n",
    "    \n",
    "    U_plus = torch.max(U_resigned, torch.tensor([0.]))\n",
    "    U_plus_norms = torch.norm(U_plus, dim=0)\n",
    "    G_dual = V_resigned\n",
    "\n",
    "    u_thresholds = torch.max(S - beta/U_plus_norms , torch.tensor([0.]))\n",
    "    U_dual = U_plus * u_thresholds\n",
    "\n",
    "    U_dual_norms = torch.norm(U_dual, dim=0)\n",
    "    u_primal_first = U_dual / (torch.sqrt(U_dual_norms)+eps)\n",
    "    v_primal_first = G_dual* (torch.sqrt(U_dual_norms))\n",
    "\n",
    "    first_model = u_primal_first @ v_primal_first.t()\n",
    "    first_model_loss = 0.5 * torch.norm(first_model - y_hot)**2 +\\\n",
    "                beta/2*(torch.norm(u_primal_first)**2 + torch.norm(v_primal_first)**2)\n",
    "\n",
    "    print('PCA training loss', first_model_loss.item())\n",
    "    losses[0] = first_model_loss.item()\n",
    "    \n",
    "    loss_2, U_value = train_constrained_cvxpy(A_white, y_hot, beta)\n",
    "    losses[1] = loss_2\n",
    "    \n",
    "    data = PrepareData(A_white, y)\n",
    "    batch_size = A_white.shape[0]\n",
    "    num_epochs = 30000\n",
    "    num_neurons = 1000\n",
    "    learning_rate = 1e-3\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        data, batch_size=batch_size, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "\n",
    "    ls, _, _, _ = sgd_solver_pytorch(train_loader, num_epochs, num_neurons, beta, \n",
    "                         learning_rate, batch_size, D_in=A_white.shape[1], verbose=True, \n",
    "                         num_classes = num_classes, train_len=A.shape[0], hot_fn=hot_fn)\n",
    "    \n",
    "    print('SGD nonconvex loss', ls[-1])\n",
    "    losses[2] = ls[-1]\n",
    "    \n",
    "    return losses\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compare only SGD and maximum-margin SVD "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_whitened_mmsvd(A_white, y, A_test, y_test, beta, num_classes= 10):\n",
    "    y_hot = one_hot(y, num_classes).double()\n",
    "    y_hot_test = one_hot(y_test, num_classes).double()\n",
    "    \n",
    "    start = time.time()\n",
    "    U, S, V = torch.svd(y_hot)\n",
    "    U = U.float()\n",
    "    S = S.float()\n",
    "    V = V.float()\n",
    "    eps = 1e-13\n",
    "\n",
    "    U_plus = torch.max(U, torch.tensor([0.]))\n",
    "    U_plus_norms = torch.norm(U_plus, dim=0)\n",
    "    U_resigned = U * (-2*(U_plus_norms < 0.5**(0.5)).float()+1)\n",
    "    V_resigned = V * (-2*(U_plus_norms < 0.5**(0.5)).float()+1)\n",
    "    \n",
    "    U_plus = torch.max(U_resigned, torch.tensor([0.]))\n",
    "    U_plus_norms = torch.norm(U_plus, dim=0)\n",
    "    G_dual = V_resigned\n",
    "\n",
    "    u_thresholds = torch.max(S - beta/U_plus_norms , torch.tensor([0.]))\n",
    "    U_dual = U_plus * u_thresholds\n",
    "\n",
    "    U_dual_norms = torch.norm(U_dual, dim=0)\n",
    "    u_primal_first = U_dual / (torch.sqrt(U_dual_norms)+eps)\n",
    "    v_primal_first = G_dual* (torch.sqrt(U_dual_norms))\n",
    "\n",
    "    first_model = u_primal_first @ v_primal_first.t()\n",
    "    \n",
    "    \n",
    "    first_model_loss = 0.5 * torch.norm(first_model - y_hot)**2 +\\\n",
    "                beta/2*(torch.norm(u_primal_first)**2 + torch.norm(v_primal_first)**2)\n",
    "    \n",
    "    first_model_preds = torch.argmax(first_model, dim=1)\n",
    "    first_model_acc = torch.eq(first_model_preds, y).float().sum()/len(y)\n",
    "    \n",
    "    svd_time = time.time() - start\n",
    "    print('Max Margin SVD training loss', first_model_loss.item())\n",
    "    print('Max Margin SVD training accuracy', first_model_acc.item())\n",
    "\n",
    "    \n",
    "    primal_u = A_white.T @ u_primal_first\n",
    "    test_output = A_test @ primal_u @ v_primal_first.t()\n",
    "    \n",
    "    first_model_test_loss = 0.5 * torch.norm(test_output - y_hot_test)**2 +\\\n",
    "                beta/2*(torch.norm(u_primal_first)**2 + torch.norm(v_primal_first)**2)\n",
    "    \n",
    "    first_model_test_preds = torch.argmax(test_output, dim=1)\n",
    "    first_model_test_acc = torch.eq(first_model_test_preds, y_test).float().sum()/len(y_test)\n",
    "    \n",
    "    print('Max Margin SVD testing loss', first_model_test_loss.item())\n",
    "    print('Max Margin SVD testing accuracy', first_model_test_acc.item())\n",
    "\n",
    "    \n",
    "    data = PrepareData(A_white, y)\n",
    "    test_data = PrepareData(A_test, y_test)\n",
    "    batch_size = A_white.shape[0]\n",
    "    num_epochs = 400\n",
    "    num_neurons = 1000\n",
    "    learning_rate = 1e-2\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        data, batch_size=batch_size, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        test_data, batch_size=1000, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "\n",
    "    ls, acc, test_ls, test_acc, times, _ = sgd_solver_pytorch(train_loader,\n",
    "                                                              num_epochs, num_neurons, beta, \n",
    "                                                                 learning_rate, batch_size, \n",
    "                                                              D_in=A_white.shape[1], verbose=True, \n",
    "                                                                     num_classes = num_classes,\n",
    "                                                                 train_len=A.shape[0], hot_fn=one_hot, \n",
    "                                                              test_loader=test_loader, return_test=True)\n",
    "    \n",
    "    print('SGD nonconvex train loss', ls[-1])\n",
    "    print('SGD nonconvex train acc', acc[-1])\n",
    "    \n",
    "    print('SGD nonconvex test loss', test_ls[-1])\n",
    "    print('SGD nonconvex test acc', test_acc[-1])\n",
    "    return times, ls, test_ls, acc, test_acc, first_model_loss, first_model_acc, first_model_test_loss, first_model_test_acc, svd_time\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cifar 10\n",
    "directory = # CHANGE DIRECTORY HERE\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
    "train_dataset = datasets.CIFAR10(\n",
    "    directory, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))\n",
    "\n",
    "test_dataset = datasets.CIFAR10(\n",
    "    directory, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3000 sample subset of CIFAR-10\n",
    "dummy_loader= torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=10000, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A, y) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "A = A.view(A.shape[0], -1)\n",
    "\n",
    "# sample 10 instances from each class\n",
    "indices = []\n",
    "all_indices = np.arange(10000)\n",
    "for cls in range(10):\n",
    "    curr_class = all_indices[y == cls]\n",
    "    ten_sample = curr_class[:300]\n",
    "    indices.extend(list(ten_sample))\n",
    "\n",
    "np.random.shuffle(indices)\n",
    "A = A[indices]\n",
    "y = y[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mean = torch.mean(A, dim=0)\n",
    "B = A - torch.mean(A, dim=0)\n",
    "A_white, train_transform = zca_whitening(B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_white.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dummy_loader= torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=len(test_dataset), shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A_test, y_test) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "        \n",
    "A_test = A_test.view(A_test.shape[0], -1)\n",
    "print(A_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_test_white= (A_test - train_mean) @ train_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = 1.0\n",
    "results = compare_whitened_mmsvd(A_white, y, A_test_white, y_test, beta, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd_times = results[0][1:] - results[0][0]\n",
    "sgd_train_losses = results[1]\n",
    "pca_train_loss = results[-5]\n",
    "pca_train_time = results[-1]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "\n",
    "plt.plot(sgd_times, sgd_train_losses, label='SGD')\n",
    "plt.hlines(pca_train_loss,0, sgd_times[-1], linestyles='dashed', label='Optimal solution (Thm 2)')\n",
    "plt.scatter(pca_train_time, pca_train_loss, marker='x', color='red', label='Achieved time of Thm 2')\n",
    "\n",
    "plt.xlabel('Time (s)')\n",
    "plt.ylabel('Train Loss')\n",
    "plt.legend()\n",
    "plt.ylim(100, 300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd_times = results[0][1:] - results[0][0]\n",
    "sgd_test_accs = results[4]\n",
    "pca_test_acc = results[-2]\n",
    "pca_train_time = results[-1]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "\n",
    "plt.plot(sgd_times, sgd_test_accs, label='SGD')\n",
    "plt.hlines(pca_test_acc,0, sgd_times[-1], linestyles='dashed', label='Optimal solution (Thm 2)')\n",
    "plt.scatter(pca_train_time, pca_test_acc, marker='x', color='red', label='Achieved time of Thm 2')\n",
    "\n",
    "plt.xlabel('Time (s)')\n",
    "plt.ylabel('Test Accuracy')\n",
    "plt.legend()\n",
    "plt.ylim(0.3, 0.4)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -- using the version downloaded from \"http://www.cs.toronto.edu/~kriz/cifar.html\"\n",
    "directory = # CHANGE DIRECTORY HERE\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
    "train_dataset = datasets.CIFAR100(\n",
    "    directory, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))\n",
    "\n",
    "test_dataset = datasets.CIFAR100(\n",
    "    directory, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3000 sample subset of CIFAR-100\n",
    "dummy_loader= torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=10000, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A, y) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "A = A.view(A.shape[0], -1)\n",
    "\n",
    "# sample 10 instances from each class\n",
    "indices = []\n",
    "all_indices = np.arange(10000)\n",
    "for cls in range(100):\n",
    "    curr_class = all_indices[y == cls]\n",
    "    ten_sample = curr_class[:30]\n",
    "    indices.extend(list(ten_sample))\n",
    "\n",
    "np.random.shuffle(indices)\n",
    "A = A[indices]\n",
    "y = y[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mean = torch.mean(A, dim=0)\n",
    "B = A - torch.mean(A, dim=0)\n",
    "A_white, train_transform = zca_whitening(B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dummy_loader= torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=len(test_dataset), shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A_test, y_test) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "        \n",
    "A_test = A_test.view(A_test.shape[0], -1)\n",
    "print(A_test.shape)\n",
    "A_test_white= (A_test - train_mean) @ train_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = 5.0\n",
    "results_cifar100 = compare_whitened_mmsvd(A_white, y, A_test_white, y_test, beta, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd_times = results_cifar100[0][1:] - results_cifar100[0][0]\n",
    "sgd_train_losses = results_cifar100[1]\n",
    "pca_train_loss = results_cifar100[-5]\n",
    "pca_train_time = results_cifar100[-1]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "\n",
    "plt.plot(sgd_times, sgd_train_losses, label='SGD')\n",
    "plt.hlines(pca_train_loss,0, sgd_times[-1], linestyles='dashed', label='Optimal solution (Thm 2)')\n",
    "plt.scatter(pca_train_time, pca_train_loss, marker='x', color='red', label='Achieved time of Thm 2')\n",
    "\n",
    "plt.xlabel('Time (s)')\n",
    "plt.ylabel('Train Loss')\n",
    "plt.legend()\n",
    "plt.ylim(1400, 1600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sgd_times = results_cifar100[0][1:] - results_cifar100[0][0]\n",
    "sgd_test_accs = results_cifar100[4]\n",
    "pca_test_acc = results_cifar100[-2]\n",
    "pca_train_time = results_cifar100[-1]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "\n",
    "plt.plot(sgd_times, sgd_test_accs, label='SGD')\n",
    "plt.hlines(pca_test_acc,0, sgd_times[-1], linestyles='dashed', label='Optimal solution (Thm 2)')\n",
    "plt.scatter(pca_train_time, pca_test_acc, marker='x', color='red', label='Achieved time of Thm 2')\n",
    "\n",
    "plt.xlabel('Time (s)')\n",
    "plt.ylabel('Test Accuracy')\n",
    "plt.legend()\n",
    "plt.ylim(0.2, 0.35)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data -- CIFAR-10 first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory =  # CHANGE DIRECTORY HERE\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
    "train_dataset = datasets.CIFAR10(\n",
    "    directory, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))\n",
    "\n",
    "test_dataset = datasets.CIFAR10(\n",
    "    directory, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 sample subset of CIFAR-10\n",
    "dummy_loader= torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=200, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A, y) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "A = A.view(A.shape[0], -1)\n",
    "\n",
    "# sample 10 instances from each class\n",
    "indices = []\n",
    "all_indices = np.arange(200)\n",
    "for cls in range(10):\n",
    "    curr_class = all_indices[y == cls]\n",
    "    ten_sample = curr_class[:10]\n",
    "    indices.extend(list(ten_sample))\n",
    "\n",
    "np.random.shuffle(indices)\n",
    "A = A[indices]\n",
    "y = y[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mean = torch.mean(A, dim=0)\n",
    "B = A - torch.mean(A, dim=0)\n",
    "A_white, train_transform = zca_whitening(B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_whitened_across_beta(A_white, y, beta_range, num_classes=10):\n",
    "    losses_onehot = np.zeros((len(beta_range), 3))\n",
    "    losses_onehot_signed = np.zeros((len(beta_range), 3))\n",
    "    for i, beta in enumerate(beta_range):\n",
    "        print('beta', beta)\n",
    "        print('unsigned')\n",
    "        losses_onehot[i] = compare_whitened(A_white, y, beta, num_classes, hot_encode=True)\n",
    "        print('signed')\n",
    "        losses_onehot_signed[i]= compare_whitened(A_white, y, beta, num_classes, hot_encode=False)\n",
    "        \n",
    "    return losses_onehot, losses_onehot_signed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_onehot, cifaronehot_signed = compare_whitened_across_beta(A_white, y, np.logspace(-1, 2, 5), num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "betas = np.logspace(-1, 2, 5)\n",
    "\n",
    "plt.loglog(betas, cifar_onehot[:, 1], 'x', markersize=10, label='Copositive relaxation')\n",
    "plt.loglog(betas, cifar_onehot[:, 2], '+', markersize=10, label='Nonconvex SGD solution')\n",
    "plt.loglog(betas, cifar_onehot[:, 0], '.', markersize=10, label='PCA solution')\n",
    "\n",
    "plt.title('Whitened CIFAR-10 classification with one-hot encoded labels')\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "betas = np.logspace(-1, 2, 5)\n",
    "\n",
    "plt.loglog(betas, cifaronehot_signed[:, 1], 'x', markersize=10, label='Copositive relaxation')\n",
    "plt.loglog(betas, cifaronehot_signed[:, 2], '+', markersize=10, label='Nonconvex SGD solution')\n",
    "\n",
    "plt.title('Whitened CIFAR-10 classification with +1, -1 labels ')\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIFAR-100 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mnist -- using the version downloaded from \"http://www.cs.toronto.edu/~kriz/cifar.html\"\n",
    "directory =  # CHANGE DIRECTORY HERE\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
    "train_dataset = datasets.CIFAR100(\n",
    "    directory, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))\n",
    "\n",
    "test_dataset = datasets.CIFAR100(\n",
    "    directory, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 sample subset of CIFAR-100\n",
    "dummy_loader= torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=1000, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A, y) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "A = A.view(A.shape[0], -1)\n",
    "\n",
    "# sample 10 instances from each class\n",
    "indices = []\n",
    "all_indices = np.arange(1000)\n",
    "for cls in range(100):\n",
    "    curr_class = all_indices[y == cls]\n",
    "    ten_sample = curr_class[:1]\n",
    "    indices.extend(list(ten_sample))\n",
    "\n",
    "print(y[indices])\n",
    "np.random.shuffle(indices)\n",
    "A = A[indices]\n",
    "y = y[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mean = torch.mean(A, dim=0)\n",
    "B = A - torch.mean(A, dim=0)\n",
    "A_white, train_transform = zca_whitening(B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar100_onehot, cifar100_onehot_signed = compare_whitened_across_beta(A_white, y, np.logspace(-1, 2, 5), num_classes=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "betas = np.logspace(-1, 2, 5)\n",
    "\n",
    "plt.loglog(betas, cifar100_onehot[:, 1], 'x', markersize=10, label='Copositive relaxation')\n",
    "plt.loglog(betas, cifar100_onehot[:, 2], '+', markersize=10, label='Nonconvex SGD solution')\n",
    "plt.loglog(betas, cifar100_onehot[:, 0], '.', markersize=10, label='PCA solution')\n",
    "\n",
    "plt.title('Whitened CIFAR-100 classification with one-hot encoded labels')\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "betas = np.logspace(-1, 2, 5)\n",
    "\n",
    "plt.loglog(betas, cifar100_onehot_signed[:, 1], 'x', markersize=10, label='Copositive relaxation')\n",
    "plt.loglog(betas, cifar100_onehot_signed[:, 2], '+', markersize=10, label='Nonconvex SGD solution')\n",
    "\n",
    "plt.title('Whitened CIFAR-100 classification with +1, -1 labels ')\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Task -- MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory =  # CHANGE DIRECTORY HERE\n",
    "\n",
    "# normalize = transforms.Normalize((0.1307,), (0.3081,))\n",
    "train_dataset = datasets.MNIST(\n",
    "    directory, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "#     normalize,\n",
    "]))\n",
    "\n",
    "test_dataset = datasets.MNIST(\n",
    "    directory, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "#     normalize,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 sample subset of MNIST\n",
    "dummy_loader= torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=200, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A, y) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "        \n",
    "A = A.view(A.shape[0], -1)\n",
    "\n",
    "# sample 10 instances from each class\n",
    "indices = []\n",
    "all_indices = np.arange(200)\n",
    "for cls in range(10):\n",
    "    curr_class = all_indices[y == cls]\n",
    "    ten_sample = curr_class[:10]\n",
    "    indices.extend(list(ten_sample))\n",
    "\n",
    "print(y[indices])\n",
    "np.random.shuffle(indices)\n",
    "A = A[indices]\n",
    "y = y[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "noise = torch.randn(A.shape[0], 784)\n",
    "whitened_noise, trans = zca_whitening(noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_generator_whitened(noise, A, beta, lr=2e-2):\n",
    "    num_classes = A.shape[1]\n",
    "    losses = np.zeros(2)\n",
    "    \n",
    "    _, loss_2 = train_constrained(noise, A, 60000, beta, rho=1.0, lr=lr)\n",
    "    losses[0] = loss_2\n",
    "    \n",
    "    data = PrepareData(noise, A)\n",
    "    batch_size = noise.shape[0]\n",
    "    num_epochs = 60000\n",
    "    num_neurons = 1000\n",
    "    learning_rate = 5e-5\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        data, batch_size=batch_size, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "    \n",
    "    hot_fn = lambda x, y: x\n",
    "\n",
    "    ls, _, _, _ = sgd_solver_pytorch(train_loader, num_epochs, num_neurons, beta, \n",
    "                         learning_rate, batch_size, D_in=noise.shape[1], verbose=True, \n",
    "                         num_classes = num_classes, train_len=noise.shape[0], hot_fn=hot_fn)\n",
    "    \n",
    "    print('SGD nonconvex loss', ls[-1])\n",
    "    losses[1] = ls[-1]\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_range = np.logspace(-1,3, 5)\n",
    "losses = np.zeros((len(beta_range), 2))\n",
    "for i in range(len(beta_range)):\n",
    "    print('beta', beta_range[i])\n",
    "    losses[i, :] = compare_generator_whitened(whitened_noise, A, beta_range[i], lr=4e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "\n",
    "plt.loglog(beta_range, losses[:, 0], 'x', markersize=10, label='Copositive relaxation')\n",
    "plt.loglog(beta_range, losses[:, 1], '+', markersize=10, label='Nonconvex SGD solution')\n",
    "\n",
    "plt.title('MNIST image generation')\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Task -- CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory =  # CHANGE DIRECTORY HERE\n",
    "\n",
    "# normalize = transforms.Normalize((0.1307,), (0.3081,))\n",
    "train_dataset = datasets.CIFAR10(\n",
    "    directory, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "#     normalize,\n",
    "]))\n",
    "\n",
    "test_dataset = datasets.CIFAR10(\n",
    "    directory, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "#     normalize,\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 sample subset of CIFAR\n",
    "dummy_loader= torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=200, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "for i, (A, y) in enumerate(dummy_loader):\n",
    "    if i == 0:\n",
    "        break\n",
    "A = A.view(A.shape[0], -1)\n",
    "\n",
    "# sample 10 instances from each class\n",
    "indices = []\n",
    "all_indices = np.arange(200)\n",
    "for cls in range(10):\n",
    "    curr_class = all_indices[y == cls]\n",
    "    ten_sample = curr_class[:10]\n",
    "    indices.extend(list(ten_sample))\n",
    "\n",
    "print(y[indices])\n",
    "np.random.shuffle(indices)\n",
    "A = A[indices]\n",
    "y = y[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "noise = torch.randn(A.shape[0], 3072)\n",
    "whitened_noise, trans = zca_whitening(noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_range = np.logspace(-1, 3, 5)\n",
    "losses = np.zeros((len(beta_range), 2))\n",
    "for i in range(len(beta_range)):\n",
    "    print('beta', beta_range[i])\n",
    "    losses[i, :] = compare_generator_whitened(whitened_noise, A, beta_range[i], lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "\n",
    "plt.loglog(beta_range, losses[:, 0], 'x', markersize=10, label='Copositive relaxation')\n",
    "plt.loglog(beta_range, losses[:, 1], '+', markersize=10, label='Nonconvex SGD solution')\n",
    "\n",
    "plt.title('CIFAR-10 image generation')\n",
    "plt.xlabel('beta')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
