{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# libraries \n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import time\n",
    "import scipy\n",
    "from scipy.sparse.linalg import LinearOperator\n",
    "import torch\n",
    "import sklearn.linear_model\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import geotorch\n",
    "sns.set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nonneg_pca(M, X):\n",
    "    n = X.shape[0]\n",
    "    d = X.shape[1]\n",
    "    \n",
    "    vals, vecs = np.linalg.eig(M)\n",
    "    idx = np.argmax(vals)\n",
    "    \n",
    "    u1 = vecs[:, idx]\n",
    "    \n",
    "    if np.all(X @ u1>= 0):\n",
    "        return u1, vals[idx]\n",
    "    elif np.all(X @ u1<= 0):\n",
    "        return -u1, vals[idx]\n",
    "        \n",
    "    elif d == 2:\n",
    "        C = []\n",
    "        for i in range(n):\n",
    "            ci = np.array([-X[i,1 ], X[i, 0]])/(np.linalg.norm(X[i, :])+1e-12)\n",
    "            if np.all(X @ ci>= 0):\n",
    "                C.append(ci)\n",
    "            elif np.all(X @ ci<= 0):\n",
    "                C.append(-ci)\n",
    "                \n",
    "        if len(C) == 0:\n",
    "#             print('infeasible!')\n",
    "            return np.zeros(d), 0.0\n",
    "        \n",
    "        quads = np.array([C[i].T @ M @ C[i] for i in range(len(C))])\n",
    "        idx = np.argmax(quads)\n",
    "        return C[idx], quads[idx]\n",
    "    \n",
    "    C = []\n",
    "    for i in range(n):\n",
    "        j = min([l for l in range(d) if X[i, l]!= 0])  \n",
    "        \n",
    "        top = np.hstack((np.eye(j), np.zeros((j, d-j-1))))\n",
    "        mid = -1/(X[i, j]) * np.delete(X[i], j)\n",
    "        bottom = np.hstack((np.zeros((d-j-1, j)), np.eye(d-j-1)))\n",
    "        \n",
    "        H = np.vstack((top, mid, bottom))\n",
    "        \n",
    "        U_h, S_h, V_h = np.linalg.svd(H, full_matrices=False)\n",
    "        M_curr = U_h.T @ M @ U_h\n",
    "        X_curr = np.delete(X, i, 0) @ U_h\n",
    "        \n",
    "        c_i_hat, _ = nonneg_pca(M_curr, X_curr)\n",
    "        \n",
    "        c_i = U_h @ c_i_hat\n",
    "        C.append(c_i)\n",
    "        \n",
    "    quads = np.array([ C[i].T @ M @ C[i] for i in range(len(C))])\n",
    "    idx = np.argmax(quads)\n",
    "    return C[idx], quads[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def datagen(ns, nc):\n",
    "    X = np.zeros((nc*ns, 2))\n",
    "    y = np.zeros((nc*ns))\n",
    "    for c in range(nc):\n",
    "        r = np.linspace(0,1,ns)/2\n",
    "        t = np.linspace(c*4, (c+1)*4, ns) + 0.15*np.random.randn(1,ns)\n",
    "        X[c*ns:(c+1)*ns, 0]  = r * np.sin(t)\n",
    "        X[c*ns:(c+1)*ns, 1]  = r * np.cos(t)\n",
    "        y[c*ns:(c+1)*ns] = c\n",
    "        \n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy so the spiral noise here and the sign patterns sampled later are reproducible.\n",
    "np.random.seed(0)\n",
    "ns = 20\n",
    "X, y = datagen(ns, 3)\n",
    "print(len(X))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 8))\n",
    "# One series per label in y, so the plot follows however many classes datagen produced.\n",
    "for label in np.unique(y):\n",
    "    plt.scatter(X[y == label, 0], X[y == label, 1], label='class %d' % label)\n",
    "plt.title('Spiral dataset')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function definitions\n",
    "def check_if_already_exists(element_list, element):\n",
    "    # check if element exists in element_list\n",
    "    # where element is a numpy array\n",
    "    return list(element) in element_list\n",
    "\n",
    "def generate_sign_patterns(A, P, verbose=False): \n",
    "    # generate sign patterns\n",
    "    n, d = A.shape\n",
    "    unique_sign_pattern_list = []  # sign patterns\n",
    "    u_vector_list = []             # random vectors used to generate the sign paterns\n",
    "\n",
    "    for i in range(P): \n",
    "        # obtain a sign pattern\n",
    "        u = np.random.normal(0, 1, (d,1)) # sample u\n",
    "        sampled_sign_pattern = (np.matmul(A, u) >= 0)[:,0]\n",
    "\n",
    "        # check whether that sign pattern has already been used\n",
    "        if not check_if_already_exists(unique_sign_pattern_list, sampled_sign_pattern):\n",
    "            unique_sign_pattern_list.append(list(sampled_sign_pattern))\n",
    "            u_vector_list.append(u)\n",
    "            \n",
    "            if verbose and len(u_vector_list)%10 == 0:\n",
    "                print(i, 'generated', len(u_vector_list), 'unique sign patterns')\n",
    "\n",
    "    if verbose:\n",
    "        print(\"Number of unique sign patterns generated: \" + str(len(unique_sign_pattern_list)))\n",
    "    return unique_sign_pattern_list, u_vector_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrepareData(Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X).float()\n",
    "        else:\n",
    "            self.X = X\n",
    "            \n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y).float()\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "\n",
    "def one_hot(labels, num_classes=10):\n",
    "    y = torch.eye(num_classes) \n",
    "    return y[labels.long()]\n",
    "\n",
    "def identity(labels, num_classses):\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample 100000 random directions; keep the distinct sign patterns of X as 0/1 int rows.\n",
    "sign_pattern_list, u_vector_list = generate_sign_patterns(X, P=100000, verbose=True)\n",
    "sign_patterns = np.asarray(sign_pattern_list, dtype=np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def frank_wolfe_nn(X, y, sign_patterns, beta, t, epochs=1, lr =0.5, \n",
    "                   use_cvxpy = False, print_freq=100, return_times=False):\n",
    "    \"\"\"Frank-Wolfe-style iterations for a convex two-layer ReLU formulation.\n",
    "\n",
    "    Reduces 1/2 ||y - sum_i (D_i X) V_i||_F^2, moving each iteration toward the\n",
    "    rank-one atom t * u g^T of the best pattern, where u comes from nonneg_pca\n",
    "    under the cone constraint (2 D_i - I) X u >= 0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : (n, d) data matrix.\n",
    "    y : (n, c) targets (must be 2-D).\n",
    "    sign_patterns : length-P sequence; sign_patterns[i] * X is the masked data D_i X\n",
    "        (e.g. (n, 1) 0/1 columns).\n",
    "    beta : regularization weight in the reported loss.\n",
    "    t : atom scale (budget).\n",
    "    epochs : number of iterations.\n",
    "    lr : unused -- the step size always follows 2 / (2 + ep); kept for compatibility.\n",
    "    use_cvxpy : pick the step size by exact line search with cvxpy instead.\n",
    "    print_freq : print the loss every print_freq iterations.\n",
    "    return_times : return (losses, times) per iteration instead of (loss, V).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (loss, V): loss = 1/2 ||residual||^2 + beta * sum_i ||V_i||_* and V is the list\n",
    "    of P (d, c) blocks; or (losses, times) when return_times is True.\n",
    "    \"\"\"\n",
    "    d = X.shape[1]\n",
    "    c = y.shape[1]\n",
    "    P = len(sign_patterns)\n",
    "    V = [np.zeros((d, c)) for _ in range(P)]\n",
    "    const = 2\n",
    "    times = np.zeros(epochs)\n",
    "    losses = np.zeros(epochs)\n",
    "\n",
    "    # Loop-invariant per-pattern data: masked X (D_i X) and cone matrix (2 D_i - I) X.\n",
    "    masks = [np.multiply(sign_patterns[i], X) for i in range(P)]\n",
    "    cones = [2 * mask - X for mask in masks]\n",
    "\n",
    "    for ep in range(epochs):\n",
    "        step = const / (const + ep)\n",
    "\n",
    "        R_k = y - sum([masks[i] @ V[i] for i in range(P)])\n",
    "        loss = 1/2 * np.linalg.norm(R_k)**2 + beta * t\n",
    "\n",
    "        losses[ep] = loss\n",
    "        times[ep] = time.time()\n",
    "\n",
    "        if ep % print_freq == 0:\n",
    "            print(ep, loss)\n",
    "\n",
    "        # Linear minimization oracle: best constrained atom for every pattern.\n",
    "        max_nonneg_pca = np.zeros(P)\n",
    "        max_v_pca = []\n",
    "        for i in range(P):\n",
    "            mask = masks[i]\n",
    "            M = mask.T @ R_k @ R_k.T @ mask\n",
    "\n",
    "            u, value = nonneg_pca(M, cones[i])\n",
    "            u = np.real(u)\n",
    "            value = np.real(value)\n",
    "\n",
    "            g = R_k.T @ mask @ u\n",
    "            if np.linalg.norm(g) != 0:\n",
    "                g = g / np.linalg.norm(g)\n",
    "\n",
    "            max_nonneg_pca[i] = value\n",
    "            max_v_pca.append(np.outer(u, g))\n",
    "\n",
    "        best = np.argmax(max_nonneg_pca)\n",
    "        v_chosen = max_v_pca[best]\n",
    "\n",
    "        if use_cvxpy:\n",
    "            # Exact line search over the step size in [0, 1].\n",
    "            step_size = cp.Variable(1)\n",
    "            V_next = [(1-step_size)*V[i] + step_size*t*v_chosen if i==best else V[i]*(1-step_size) for i in range(P)]\n",
    "            res_next = y - sum([masks[i] @ V_next[i] for i in range(P)])\n",
    "            obj = cp.Minimize(cp.norm(res_next)**2)\n",
    "            constr = [step_size >= 0, step_size <=1]\n",
    "            prob = cp.Problem(obj, constr)\n",
    "            prob.solve()\n",
    "            V = [V_next[i].value for i in range(P)]\n",
    "            if ep % print_freq == 0:\n",
    "                print(step_size.value)\n",
    "        else:\n",
    "            V = [(1-step)*V[i] + step*t*v_chosen if i==best else V[i]*(1-step) for i in range(P)]\n",
    "\n",
    "    R_k = y - sum([masks[i] @ V[i] for i in range(P)])\n",
    "    loss = 1/2 * np.linalg.norm(R_k)**2 + beta * np.sum([np.linalg.norm(V[i], 'nuc') for i in range(P)])\n",
    "\n",
    "    if return_times:\n",
    "        return losses, times\n",
    "    return loss, V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FCNetwork(nn.Module):\n",
    "    def __init__(self, H, num_classes=10, input_dim=784):\n",
    "        self.num_classes = num_classes\n",
    "        super(FCNetwork, self).__init__()\n",
    "        self.layer1 = nn.Sequential(nn.Linear(input_dim, H, bias=False), nn.ReLU())\n",
    "        self.layer2 = nn.Linear(H, num_classes, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.reshape(x.size(0), -1)\n",
    "        out = self.layer2(self.layer1(x))\n",
    "        return out\n",
    "\n",
    "\n",
    "def loss_func_primal(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    ## l2 norm on first layer weights, l1 squared norm on second layer\n",
    "    for layer, p in enumerate(model.parameters()):\n",
    "        loss += beta/2 * torch.norm(p)**2\n",
    "    \n",
    "    return loss\n",
    "\n",
    "# solves nonconvex problem\n",
    "def sgd_solver_pytorch(ds, num_epochs, num_neurons, beta, \n",
    "                         learning_rate, batch_size, verbose=False, \n",
    "                        num_classes=10, D_in=784, train_len=60000, hot_fn=one_hot,\n",
    "                        print_freq=1000, device=None):\n",
    "    \"\"\"Train FCNetwork with SGD (momentum 0.95) on loss_func_primal.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ds : iterable of (x, y) minibatches that supports len() (e.g. a DataLoader).\n",
    "    num_epochs : number of passes over ds.\n",
    "    num_neurons : hidden width.\n",
    "    beta : weight-decay coefficient passed to loss_func_primal.\n",
    "    learning_rate : initial SGD learning rate (halved by ReduceLROnPlateau).\n",
    "    batch_size, train_len : unused; kept for backward compatibility.\n",
    "    verbose : forwarded to the LR scheduler.\n",
    "    num_classes, D_in : output and input dimensions of the network.\n",
    "    hot_fn : hot_fn(labels, num_classes) maps raw labels to regression targets.\n",
    "    print_freq : print progress every print_freq epochs.\n",
    "    device : torch device; defaults to CUDA when available, else CPU.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    losses : per-minibatch losses, length num_epochs * len(ds).\n",
    "    accs : same length; never filled in (all zeros).\n",
    "    times : times[0] is the start time, times[k] the time after update k.\n",
    "    model : the trained FCNetwork.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    H, D_out = num_neurons, num_classes\n",
    "    model = FCNetwork(H, D_out, D_in).to(device)\n",
    "\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.95)\n",
    "\n",
    "    # One slot per minibatch update (not per epoch), so multi-batch loaders fit.\n",
    "    total_iters = int(num_epochs) * len(ds)\n",
    "    losses = np.zeros(total_iters)\n",
    "    accs = np.zeros(losses.shape)\n",
    "    times = np.zeros(total_iters + 1)\n",
    "    times[0] = time.time()\n",
    "\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           verbose=verbose,\n",
    "                                                           factor=0.5,\n",
    "                                                          eps=1e-12,\n",
    "                                                          patience=100)\n",
    "\n",
    "    iter_no = 0\n",
    "    for i in range(num_epochs):\n",
    "        for _x, _y in ds:\n",
    "            _x = _x.to(device)\n",
    "            # Build targets on the loader's device, then move them: one_hot indexes a\n",
    "            # CPU identity matrix, which can fail with device-side index tensors.\n",
    "            target = hot_fn(_y, num_classes).to(device)\n",
    "\n",
    "            yhat = model(_x).float()\n",
    "            loss = loss_func_primal(yhat, target, model, beta)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {}, acc: {}\".format(i, num_epochs,\n",
    "                    losses[iter_no-1], accs[iter_no-1]))\n",
    "\n",
    "        scheduler.step(losses[iter_no-1])\n",
    "\n",
    "    return losses, accs, times, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_across_beta(X, y, beta_range, sign_patterns, num_classes = 3, \n",
    "                        hot_fn = one_hot, init_lr = None, primal_epochs = 8001, dual_epochs=None,\n",
    "                        num_neurons=1000):\n",
    "    \"\"\"For each beta, train the nonconvex network with SGD, then run frank_wolfe_nn.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X, y : data matrix and integer-valued labels.\n",
    "    beta_range : regularization values to sweep.\n",
    "    sign_patterns : (P, n) 0/1 array; expanded to (P, n, 1) for frank_wolfe_nn.\n",
    "    num_classes, hot_fn : number of outputs and label-to-target map.\n",
    "    init_lr : SGD learning rate; if None it is chosen per beta.\n",
    "    primal_epochs : SGD epochs.\n",
    "    dual_epochs : Frank-Wolfe iterations; if None, chosen per beta\n",
    "        (15001 for beta < 0.01, else 1001).\n",
    "    num_neurons : hidden width of the SGD network.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (len(beta_range), 2) array: column 0 is the final SGD minibatch loss,\n",
    "    column 1 is the loss returned by frank_wolfe_nn.\n",
    "    \"\"\"\n",
    "    losses = np.zeros((len(beta_range), 2))\n",
    "\n",
    "    # Data, loader, targets and patterns do not depend on beta: build them once.\n",
    "    data = PrepareData(X, y)\n",
    "    batch_size = X.shape[0]  # full-batch training\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        data, batch_size=batch_size, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "    y_hot = hot_fn(torch.Tensor(y), num_classes).numpy()\n",
    "    sign_patterns_3d = np.expand_dims(sign_patterns, 2)\n",
    "\n",
    "    for i, beta in enumerate(beta_range):\n",
    "        print('beta', beta)\n",
    "\n",
    "        if init_lr is not None:\n",
    "            learning_rate = init_lr\n",
    "        elif beta < .01:\n",
    "            learning_rate = 1e-2\n",
    "        elif beta < 1:\n",
    "            learning_rate = 1e-3\n",
    "        else:\n",
    "            learning_rate = 2e-4\n",
    "\n",
    "        ls, _, _, model = sgd_solver_pytorch(train_loader, primal_epochs, num_neurons, beta, \n",
    "                             learning_rate, batch_size, D_in=X.shape[1], verbose=True, \n",
    "                             num_classes = num_classes, train_len=X.shape[0], hot_fn=hot_fn)\n",
    "        losses[i, 0] = ls[-1]\n",
    "\n",
    "        # Frank-Wolfe budget: half the squared norm of all SGD weights.\n",
    "        t = sum(torch.norm(p)**2 / 2 for p in model.parameters()).cpu().item()\n",
    "        print('t', t)\n",
    "\n",
    "        # Resolved per beta without overwriting the argument, so every beta gets its own default.\n",
    "        if dual_epochs is not None:\n",
    "            epochs_i = dual_epochs\n",
    "        elif beta < .01:\n",
    "            epochs_i = 15001\n",
    "        else:\n",
    "            epochs_i = 1001\n",
    "\n",
    "        loss, _ = frank_wolfe_nn(X, y_hot, sign_patterns_3d, beta, t, epochs=epochs_i)\n",
    "        losses[i, 1] = loss\n",
    "\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep 5 log-spaced betas in [1e-2, 1e2]; column 0 = SGD loss, column 1 = Frank-Wolfe loss.\n",
    "losses = compare_across_beta(X, y, beta_range=np.logspace(-2, 2, 5), sign_patterns=sign_patterns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "betas = np.logspace(-2, 2, 5)\n",
    "fig, ax = plt.subplots(figsize=(8, 4))\n",
    "ax.semilogx(betas, losses[:, 1], 'x', markersize=10, label='Algorithm 1')\n",
    "ax.semilogx(betas, losses[:, 0], '+', markersize=10, label='Nonconvex SGD solution')\n",
    "ax.set_title('Spiral Classification')\n",
    "ax.set_xlabel('beta')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solve the convex NN program corresponding to convex relaxation in data whitened case.\n",
    "def train_constrained_cvxpy(A, y_hot, sign_patterns, beta, eps=1e-4):\n",
    "    \"\"\"Solve the copositive relaxation with cvxpy + SCS and print its objective.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    A : (n, d) numpy data matrix.\n",
    "    y_hot : (n, c) targets, as a torch tensor or a numpy array.\n",
    "    sign_patterns : length-P sequence; sign_patterns[i] * A is the masked data\n",
    "        (e.g. (n, 1) 0/1 columns).\n",
    "    beta : regularization weight; the trace penalty is scaled by beta**2.\n",
    "    eps : SCS tolerance.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (objective value, list of P (d, d) PSD cvxpy Variables).\n",
    "    \"\"\"\n",
    "    n = len(A)\n",
    "    P = len(sign_patterns)\n",
    "    d = A.shape[1]\n",
    "    # Accept a torch tensor (on any device) or anything numpy can convert.\n",
    "    if torch.is_tensor(y_hot):\n",
    "        y_hot = y_hot.detach().cpu().numpy()\n",
    "    else:\n",
    "        y_hot = np.asarray(y_hot)\n",
    "    U_list = [cp.Variable((d, d), PSD=True) for i in range(P)]\n",
    "    masked_A = [sign_patterns[i]* A for i in range(P)]\n",
    "    preds = sum([masked_A[i] @ U_list[i] @ masked_A[i].T for i in range(P)])\n",
    "\n",
    "    # numpy identity keeps the cvxpy expression free of torch tensors (no implicit conversion).\n",
    "    objective = cp.Minimize(0.5 * cp.matrix_frac(y_hot, np.eye(n) + 2 * preds) +\n",
    "                            beta**2 * sum([cp.trace(U_list[i]) for i in range(P)]))\n",
    "\n",
    "    # Elementwise nonnegativity of ((2 D_i - I) A) U_i ((2 D_i - I) A)^T for each pattern.\n",
    "    masked_A_2 = [(2*sign_patterns[i] - 1)* A for i in range(P)]\n",
    "    constraints = [masked_A_2[i] @ U_list[i] @ masked_A_2[i].T >= 0 for i in range(P)]\n",
    "    problem = cp.Problem(objective, constraints)\n",
    "    problem.solve(solver='SCS', eps=eps, max_iters=20000, verbose=True)\n",
    "    print('copositive relaxation loss', objective.value)\n",
    "    return objective.value, U_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_hot = one_hot(torch.Tensor(y), 3)\n",
    "losses_cvx_relaxation = []\n",
    "for beta in np.logspace(-2, 2, 5):\n",
    "    # Keep only the objective value (first element of the returned pair).\n",
    "    relaxation_loss, _ = train_constrained_cvxpy(X, y_hot, np.expand_dims(sign_patterns, 2), beta, eps=1e-2)\n",
    "    losses_cvx_relaxation.append(relaxation_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "betas = np.logspace(-2, 2, 5)\n",
    "fig, ax = plt.subplots(figsize=(8, 4))\n",
    "series = [\n",
    "    (losses[:, 1], 'x', 'Algorithm 1'),\n",
    "    (losses[:, 0], '+', 'Non-convex SGD solution'),\n",
    "    (losses_cvx_relaxation, '.', 'Copositive relaxation'),\n",
    "]\n",
    "for values, marker, name in series:\n",
    "    ax.semilogx(betas, values, marker, markersize=10, label=name)\n",
    "ax.set_xlabel('beta')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random regression problem: n points in d dimensions, targets from a random\n",
    "# FCNetwork with 100 hidden units and 5 outputs. Seeded for reproducibility.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "n = 25\n",
    "d = 2\n",
    "random_dataset_2 = np.random.randn(n, d)\n",
    "gen_model = FCNetwork(100, 5, d)\n",
    "labels_2 = gen_model(torch.Tensor(random_dataset_2)).detach().numpy()\n",
    "sign_pattern_list, u_vector_list = generate_sign_patterns(random_dataset_2, 100000, verbose=True)\n",
    "random_sign_patterns_2 = np.array(sign_pattern_list).astype(np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (n, 5): one row per data point, one column per output of gen_model\n",
    "np.shape(labels_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 5\n",
    "losses_all_3neuron = []\n",
    "\n",
    "beta = 0.01\n",
    "num_neurons = 3\n",
    "learning_rate = 1e-3\n",
    "data = PrepareData(random_dataset_2, labels_2)\n",
    "batch_size = random_dataset_2.shape[0]  # full batch\n",
    "# Fixed dataset and shuffle=False, so a single loader is reused by every run.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    data, batch_size=batch_size, shuffle=False,\n",
    "    pin_memory=True, sampler=None)\n",
    "\n",
    "for run in range(num_runs):\n",
    "    print('run', run)\n",
    "    ls, _, _, model = sgd_solver_pytorch(\n",
    "        train_loader, 10000, num_neurons, beta, learning_rate, batch_size,\n",
    "        D_in=random_dataset_2.shape[1], verbose=True, num_classes=5,\n",
    "        train_len=random_dataset_2.shape[0], hot_fn=identity)\n",
    "    losses_all_3neuron.append(ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 5\n",
    "num_neurons = 10\n",
    "learning_rate = 1e-3\n",
    "losses_all_10neuron = []\n",
    "\n",
    "# train_loader, beta and batch_size come from the 3-neuron cell.\n",
    "for run in range(num_runs):\n",
    "    print('run', run)\n",
    "    ls, _, _, model = sgd_solver_pytorch(\n",
    "        train_loader, 20000, num_neurons, beta, learning_rate, batch_size,\n",
    "        D_in=random_dataset_2.shape[1], verbose=True, num_classes=5,\n",
    "        train_len=random_dataset_2.shape[0], hot_fn=identity)\n",
    "    losses_all_10neuron.append(ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 5\n",
    "num_neurons = 50\n",
    "learning_rate = 1e-3\n",
    "losses_all_50neuron = []\n",
    "\n",
    "# train_loader, beta and batch_size come from the 3-neuron cell.\n",
    "for run in range(num_runs):\n",
    "    print('run', run)\n",
    "    ls, _, _, model = sgd_solver_pytorch(\n",
    "        train_loader, 20000, num_neurons, beta, learning_rate, batch_size,\n",
    "        D_in=random_dataset_2.shape[1], verbose=True, num_classes=5,\n",
    "        train_len=random_dataset_2.shape[0], hot_fn=identity)\n",
    "    losses_all_50neuron.append(ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 1\n",
    "num_neurons = 1000\n",
    "learning_rate = 1e-3\n",
    "losses_all_1000neuron = []\n",
    "\n",
    "# train_loader, beta and batch_size come from the 3-neuron cell; the last model\n",
    "# trained here is the one whose weight norm the next cell measures.\n",
    "for run in range(num_runs):\n",
    "    print('run', run)\n",
    "    ls, _, _, model = sgd_solver_pytorch(\n",
    "        train_loader, 50000, num_neurons, beta, learning_rate, batch_size,\n",
    "        D_in=random_dataset_2.shape[1], verbose=True, num_classes=5,\n",
    "        train_len=random_dataset_2.shape[0], hot_fn=identity)\n",
    "    losses_all_1000neuron.append(ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Half the squared norm of every parameter of the most recently trained model.\n",
    "t = sum(torch.norm(p)**2 / 2 for p in model.parameters())\n",
    "t = t.cpu().item()\n",
    "\n",
    "print('t', t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this replaces the t computed from the trained model in the previous\n",
    "# cell with a fixed constant; confirm that is intended.\n",
    "t = 1.495\n",
    "losses_fw, times_fw = frank_wolfe_nn(\n",
    "    random_dataset_2, labels_2, np.expand_dims(random_sign_patterns_2, 2), beta, t,\n",
    "    epochs=30000, use_cvxpy=False, print_freq=200, return_times=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rerun the 3-neuron configuration, keeping only the recorded timestamps.\n",
    "times_3 = sgd_solver_pytorch(\n",
    "    train_loader, 10000, 3, beta, learning_rate, batch_size,\n",
    "    D_in=random_dataset_2.shape[1], verbose=True, num_classes=5,\n",
    "    train_len=random_dataset_2.shape[0], hot_fn=identity)[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "# Derive the axis and trial count from the recorded losses, not hard-coded constants.\n",
    "epochs_3 = np.arange(len(losses_all_3neuron[0]))\n",
    "\n",
    "for i, trial_losses in enumerate(losses_all_3neuron):\n",
    "    plt.plot(epochs_3, trial_losses, label='SGD, Trial ' + str(i+1))\n",
    "plt.hlines(losses_fw[-1], 0, len(epochs_3), linestyles='dashed', label='Optimal solution (Algorithm 1)')\n",
    "\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.ylim(0, 0.25)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rerun the 10-neuron configuration, keeping only the recorded timestamps.\n",
    "times_10 = sgd_solver_pytorch(\n",
    "    train_loader, 20000, 10, beta, learning_rate, batch_size,\n",
    "    D_in=random_dataset_2.shape[1], verbose=True, num_classes=5,\n",
    "    train_len=random_dataset_2.shape[0], hot_fn=identity)[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "# Derive the axis and trial count from the recorded losses, not hard-coded constants.\n",
    "epochs_10 = np.arange(len(losses_all_10neuron[0]))\n",
    "\n",
    "for i, trial_losses in enumerate(losses_all_10neuron):\n",
    "    plt.plot(epochs_10, trial_losses, label='SGD, Trial ' + str(i+1))\n",
    "plt.hlines(losses_fw[-1], 0, len(epochs_10), linestyles='dashed', label='Optimal solution (Algorithm 1)')\n",
    "\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.ylim(0.015, 0.02)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "# Derive the axis and trial count from the recorded losses, not hard-coded constants.\n",
    "epochs_50 = np.arange(len(losses_all_50neuron[0]))\n",
    "\n",
    "for i, trial_losses in enumerate(losses_all_50neuron):\n",
    "    plt.plot(epochs_50, trial_losses, label='SGD, Trial ' + str(i+1))\n",
    "plt.hlines(losses_fw[-1], 0, len(epochs_50), linestyles='dashed', label='Optimal solution (Algorithm 1)')\n",
    "\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "plt.ylim(0.015, 0.02)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
