{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# libraries \n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import time\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import math\n",
    "sns.set()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Spiral Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def datagen(ns, nc):\n",
    "    X = np.zeros((nc*ns, 2))\n",
    "    y = np.zeros((nc*ns))\n",
    "    for c in range(nc):\n",
    "        r = np.linspace(0,1,ns)/2\n",
    "        t = np.linspace(c*4, (c+1)*4, ns) + 0.15*np.random.randn(1,ns)\n",
    "        X[c*ns:(c+1)*ns, 0]  = r * np.sin(t)\n",
    "        X[c*ns:(c+1)*ns, 1]  = r * np.cos(t)\n",
    "        y[c*ns:(c+1)*ns] = c\n",
    "        \n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "ns = 50\n",
    "nc = 3\n",
    "X, y = datagen(ns, nc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sns.set(font_scale = 1.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 8))\n",
    "\n",
    "for j in range(nc):\n",
    "    plt.scatter(X[ns*j:ns*(j+1), 0], X[ns*j:ns*(j+1), 1], label='class ' +str(j))\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Code for Training Gated ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_sign_patterns(A, P, verbose=False): \n",
    "    # generate sign patterns\n",
    "    n, d = A.shape\n",
    "    unique_sign_pattern_list = []  # sign patterns\n",
    "    u_vector_list = []             # random vectors used to generate the sign paterns\n",
    "\n",
    "    for i in range(P): \n",
    "        # obtain a sign pattern\n",
    "        u = np.random.normal(0, 1, (d,1)) # sample u\n",
    "        sampled_sign_pattern = (np.matmul(A, u) >= 0)[:,0]\n",
    "        unique_sign_pattern_list.append(list(sampled_sign_pattern))\n",
    "        u_vector_list.append(u)\n",
    "\n",
    "    if verbose:\n",
    "        print(\"Number of unique sign patterns generated: \" + str(len(unique_sign_pattern_list)))\n",
    "    return unique_sign_pattern_list, u_vector_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class PrepareData(Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X).float()\n",
    "        else:\n",
    "            self.X = X\n",
    "            \n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y).float()\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "    \n",
    "def one_hot_signed(labels, num_classes=10):\n",
    "    y = torch.eye(num_classes) \n",
    "    return 2*y[labels.long()] - 1\n",
    "\n",
    "def one_hot(labels, num_classes=10):\n",
    "    y = torch.eye(num_classes) \n",
    "    return y[labels.long()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class convex_net(nn.Module):\n",
    "    def __init__(self, n, d,c, u_vectors, linear=False):\n",
    "        super(convex_net, self).__init__()\n",
    "        self.u_vectors = u_vectors\n",
    "        self.P = u_vectors.shape[0]\n",
    "        self.d = d\n",
    "        self.c = c\n",
    "        self.Z = nn.Parameter(torch.randn(self.P, self.d, c), requires_grad=True)\n",
    "        self.linear = linear\n",
    "        \n",
    "    def forward(self, x):\n",
    "        with torch.no_grad():\n",
    "            if not self.linear:\n",
    "                sign_patterns = (x @ self.u_vectors.t() >= 0).double() # n x P\n",
    "            else:\n",
    "                sign_patterns = torch.ones((x.shape[0], self.u_vectors.shape[0])).double()\n",
    "        \n",
    "        return torch.einsum('nd, pdc, np -> nc', x, self.Z, sign_patterns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_convex_cvxpy(A, y_hot, u_vectors, beta, linear=False):\n",
    "    P = u_vectors.shape[0]\n",
    "    n = A.shape[0]\n",
    "    d = A.shape[1]\n",
    "    c = y_hot.shape[1]\n",
    "    Z = [cp.Variable((d, c)) for i in range(P)]\n",
    "    if not linear:\n",
    "        sign_patterns = (A  @ u_vectors.T >= 0).astype(np.float32) # n x P\n",
    "    else:\n",
    "        sign_patterns = np.ones((n, P)).astype(np.float32)\n",
    "    preds = sum([(np.expand_dims(sign_patterns[:, i], 1) * A) @ Z[i] for i in range(P)])\n",
    "    objective = cp.Minimize( 1/(n*c)*cp.norm(preds - y_hot, 'fro')**2 + beta*sum([cp.norm(Z[i], 'nuc') for i in range(P)]))\n",
    "    problem = cp.Problem(objective, [])\n",
    "    problem.solve(verbose=True, solver=cp.MOSEK)\n",
    "    \n",
    "    return objective.value, Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class bm_net(nn.Module):\n",
    "    def __init__(self, n, d, c, u_vectors, m=1, linear=False):\n",
    "        super(bm_net, self).__init__()\n",
    "        self.u_vectors = u_vectors\n",
    "        self.P = u_vectors.shape[0]\n",
    "        self.d = d\n",
    "        self.c = c\n",
    "        self.m = 1\n",
    "        self.linear = linear\n",
    "        self.U = nn.Parameter(torch.randn(self.P, self.d, self.m)/math.sqrt(self.d), requires_grad=True)\n",
    "        self.V = nn.Parameter(torch.randn(self.P, self.m, self.c)/self.m, requires_grad=True)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        with torch.no_grad():\n",
    "            if not self.linear:\n",
    "                sign_patterns = (x @ self.u_vectors.t() >= 0).double() # n x P \n",
    "            else:\n",
    "                sign_patterns = torch.ones((x.shape[0], self.u_vectors.shape[0])).double()\n",
    "        \n",
    "        return torch.einsum('nd, pdm, pmc, np -> nc', x, self.U, self.V, sign_patterns)\n",
    "\n",
    "def train_bm(A, y_hot, u_vectors, num_epochs, beta, lr, m=1, linear=False):\n",
    "    model =  bm_net(A.shape[0], A.shape[1], y_hot.shape[1], u_vectors.double(), m, linear=linear).double()\n",
    "    P = u_vectors.shape[0]\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=beta)\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           verbose=False,\n",
    "                                                           factor=0.9,\n",
    "                                                          eps=1e-20,\n",
    "                                                          patience=350)\n",
    "    \n",
    "    criterion = torch.nn.MSELoss()\n",
    "    data = PrepareData(A.double(), y_hot.double())\n",
    "    batch_size = A.shape[0]\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        data, batch_size=batch_size, shuffle=False,\n",
    "        pin_memory=True, sampler=None)\n",
    "    losses_epoch = []\n",
    "        \n",
    "    for epoch in range(num_epochs):\n",
    "        loss_epoch = 0\n",
    "        for ix, (x, y) in enumerate(train_loader):\n",
    "            pred = model(x)\n",
    "\n",
    "            loss = criterion(pred, y) \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            # for tracking regularized loss curves\n",
    "            loss_incl_reg = loss + beta/2 * (torch.norm(model.U)**2 + torch.norm(model.V)**2)\n",
    "            loss_epoch += loss_incl_reg.item()\n",
    "            \n",
    "        loss_epoch /= len(train_loader)\n",
    "        if epoch % 1000 == 0:\n",
    "            print(epoch, loss_epoch)\n",
    "        scheduler.step(loss_epoch)\n",
    "        losses_epoch.append(loss_epoch)\n",
    "        \n",
    "    print('checking dual qualification constraint')\n",
    "    convex_model = convex_net(A.shape[0], A.shape[1], y_hot.shape[1], u_vectors.double(), linear=linear).double()\n",
    "    convex_model.Z.data = model.U @ model.V\n",
    "    \n",
    "    loss_epoch = 0\n",
    "    for ix, (x, y) in enumerate(train_loader):\n",
    "        pred = convex_model(x)\n",
    "\n",
    "        loss = criterion(pred, y)*len(x)/len(A)\n",
    "        loss.backward()\n",
    "        loss_incl_reg = loss + beta * sum([torch.norm(convex_model.Z[i]) for i in range(P)])*len(x)/len(A)\n",
    "        loss_epoch += loss_incl_reg.item()\n",
    "        \n",
    "    print('final loss', loss_epoch)\n",
    "    \n",
    "    grad = convex_model.Z.grad\n",
    "    op_norms = [np.linalg.norm(grad[i], 2) for i in range(P)]\n",
    "    max_op_norm = max(op_norms)\n",
    "    constr = max_op_norm/beta - 1\n",
    "    print('dual constraint', constr)\n",
    "#     print(op_norms)\n",
    "    \n",
    "    return convex_model, losses_epoch, constr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def compare_convex_bm(A, y_hot, u_vectors, beta, lr, m_range=[1, 2, 5], linear=False, num_epochs=25000):\n",
    "    cvxpy_sol =  train_convex_cvxpy(A.numpy(), y_hot.numpy(), u_vectors, beta, linear=linear)\n",
    "    trained_convex_model = convex_net(A.shape[0], A.shape[1], y_hot.shape[1], torch.from_numpy(u_vectors).double(), linear=linear).double()\n",
    "    for i in range(u_vectors.shape[0]):\n",
    "        trained_convex_model.Z.data[i]= torch.from_numpy(cvxpy_sol[1][i].value).double()\n",
    "\n",
    "    A, y_hot = A.double(), y_hot.double()\n",
    "    P = u_vectors.shape[0]\n",
    "    criterion = torch.nn.MSELoss()\n",
    "    predfull = trained_convex_model(A)\n",
    "    convex_loss = criterion(predfull, y_hot) +beta *sum([torch.norm(trained_convex_model.Z[i], 'nuc') for i in range(u_vectors.shape[0])])\n",
    "    cvx_loss = convex_loss.item()\n",
    "    print('convex loss', cvx_loss)\n",
    "    \n",
    "    bm_losses_all = []\n",
    "    constraints_all = []\n",
    "    \n",
    "    for m in m_range:\n",
    "        print('m', m)\n",
    "        convex_bm_model, bm_losses, constr = train_bm(A, y_hot, \n",
    "                      torch.from_numpy(u_vectors).float(),\n",
    "                      num_epochs, beta, lr, m=m, linear=linear)\n",
    "        bm_losses_all.append(bm_losses)\n",
    "        constraints_all.append(constr)\n",
    "        \n",
    "    return cvx_loss, bm_losses_all, constraints_all\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def compute_qualifications_over_n_beta(ns_range=[1, 5, 15, 25, 50], \n",
    "                                       beta_range=[1e-4, 1e-3, 1e-2, 1e-1],\n",
    "                                       linear=False, lr=1.0):\n",
    "    nc = 3\n",
    "    results_all = []\n",
    "    params_all = []\n",
    "    \n",
    "    for n in ns_range:\n",
    "        X, y = datagen(ns, nc)\n",
    "        if not linear:\n",
    "            sign_pattern_list, u_vector_list = generate_sign_patterns(X, 100, verbose=True)\n",
    "        else:\n",
    "            sign_pattern_list, u_vector_list = generate_sign_patterns(X, 1, verbose=True)\n",
    "        sign_patterns = np.array(sign_pattern_list).astype(np.int32)\n",
    "        u_vectors = np.array(u_vector_list)[:, :, 0]\n",
    "        P = u_vectors.shape[0]\n",
    "        \n",
    "        for beta in beta_range:\n",
    "            print('comparing', n*3, beta)\n",
    "            result_curr = compare_convex_bm(torch.from_numpy(X).float(), one_hot(torch.Tensor(y), nc), \n",
    "                              u_vectors, beta, lr=lr, num_epochs=20000, linear=linear)\n",
    "            results_all.append(result_curr)\n",
    "            params_all.append((n*3, beta))\n",
    "    \n",
    "    \n",
    "    return results_all, params_all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gated ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "results_all_0, params_all_0 = compute_qualifications_over_n_beta()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "beta_values = [1e-4, 1e-3, 1e-2, 1e-1]\n",
    "n_values = 3*np.array([1, 5, 15, 25, 50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_qc_fb = []\n",
    "all_fl_fb = []\n",
    "all_cvx_fb = []\n",
    "all_n_fb = []\n",
    "\n",
    "for beta in beta_values:\n",
    "    qualification_constraints_fixed_beta = []\n",
    "    final_losses_fixed_beta = []\n",
    "    cvx_losses_fixed_beta = []\n",
    "    n_values_fixed_beta = []\n",
    "    for i, result in enumerate(results_all_0):\n",
    "        params = params_all_0[i]\n",
    "        if params[1] != beta:\n",
    "            continue\n",
    "        \n",
    "        n_values_fixed_beta.append(params[0])\n",
    "        cvx_losses_fixed_beta.append(result[0])\n",
    "        final_losses_fixed_beta.append([result[1][j][-1] for j in range(len(result[1]))])\n",
    "        qualification_constraints_fixed_beta.append(result[2])\n",
    "        \n",
    "    all_qc_fb.append(qualification_constraints_fixed_beta)\n",
    "    all_fl_fb.append(final_losses_fixed_beta)\n",
    "    all_cvx_fb.append(cvx_losses_fixed_beta)\n",
    "    all_n_fb.append(n_values_fixed_beta)\n",
    "    \n",
    "all_qc_fb = np.array(all_qc_fb)\n",
    "all_fl_fb = np.array(all_fl_fb)\n",
    "all_cvx_fb = np.array(all_cvx_fb)\n",
    "all_n_fb = np.array(all_n_fb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "results_all_02, params_all_02 = compute_qualifications_over_n_beta()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "results_all_03, params_all_03 = compute_qualifications_over_n_beta()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_qc_fb2 = []\n",
    "all_fl_fb2 = []\n",
    "all_cvx_fb2 = []\n",
    "all_n_fb2 = []\n",
    "\n",
    "for beta in beta_values:\n",
    "    qualification_constraints_fixed_beta = []\n",
    "    final_losses_fixed_beta = []\n",
    "    cvx_losses_fixed_beta = []\n",
    "    n_values_fixed_beta = []\n",
    "    for i, result in enumerate(results_all_02):\n",
    "        params = params_all_02[i]\n",
    "        if params[1] != beta:\n",
    "            continue\n",
    "        \n",
    "        n_values_fixed_beta.append(params[0])\n",
    "        cvx_losses_fixed_beta.append(result[0])\n",
    "        final_losses_fixed_beta.append([result[1][j][-1] for j in range(len(result[1]))])\n",
    "        qualification_constraints_fixed_beta.append(result[2])\n",
    "        \n",
    "    all_qc_fb2.append(qualification_constraints_fixed_beta)\n",
    "    all_fl_fb2.append(final_losses_fixed_beta)\n",
    "    all_cvx_fb2.append(cvx_losses_fixed_beta)\n",
    "    all_n_fb2.append(n_values_fixed_beta)\n",
    "    \n",
    "all_qc_fb2 = np.array(all_qc_fb2)\n",
    "all_fl_fb2 = np.array(all_fl_fb2)\n",
    "all_cvx_fb2 = np.array(all_cvx_fb2)\n",
    "all_n_fb2 = np.array(all_n_fb2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_qc_fb3 = []\n",
    "all_fl_fb3 = []\n",
    "all_cvx_fb3 = []\n",
    "all_n_fb3 = []\n",
    "\n",
    "for beta in beta_values:\n",
    "    qualification_constraints_fixed_beta = []\n",
    "    final_losses_fixed_beta = []\n",
    "    cvx_losses_fixed_beta = []\n",
    "    n_values_fixed_beta = []\n",
    "    for i, result in enumerate(results_all_03):\n",
    "        params = params_all_03[i]\n",
    "        if params[1] != beta:\n",
    "            continue\n",
    "        \n",
    "        n_values_fixed_beta.append(params[0])\n",
    "        cvx_losses_fixed_beta.append(result[0])\n",
    "        final_losses_fixed_beta.append([result[1][j][-1] for j in range(len(result[1]))])\n",
    "        qualification_constraints_fixed_beta.append(result[2])\n",
    "        \n",
    "    all_qc_fb3.append(qualification_constraints_fixed_beta)\n",
    "    all_fl_fb3.append(final_losses_fixed_beta)\n",
    "    all_cvx_fb3.append(cvx_losses_fixed_beta)\n",
    "    all_n_fb3.append(n_values_fixed_beta)\n",
    "    \n",
    "all_qc_fb3 = np.array(all_qc_fb3)\n",
    "all_fl_fb3 = np.array(all_fl_fb3)\n",
    "all_cvx_fb3 = np.array(all_cvx_fb3)\n",
    "all_n_fb3 = np.array(all_n_fb3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_qc_fb_all = np.maximum(np.stack((all_qc_fb, all_qc_fb2, all_qc_fb3)), 1e-14)\n",
    "all_fl_fb_all = np.stack((all_fl_fb, all_fl_fb2, all_fl_fb3))\n",
    "all_cvx_fb_all = np.stack((all_cvx_fb, all_cvx_fb2, all_cvx_fb3))\n",
    "all_n_fb_all = np.stack((all_n_fb, all_n_fb2, all_n_fb3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_qc_fb_mean = np.mean(all_qc_fb_all, 0)\n",
    "all_qc_fb_upper = np.max(all_qc_fb_all, 0) - all_qc_fb_mean\n",
    "all_qc_fb_lower = all_qc_fb_mean - np.min(all_qc_fb_all, 0)\n",
    "all_qc_fb_errs = np.stack((all_qc_fb_lower, all_qc_fb_upper))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_actual_fb_all = np.maximum((all_fl_fb_all - np.expand_dims(all_cvx_fb_all, 3))/np.expand_dims(all_cvx_fb_all, 3), 1e-14)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_actual_fb_mean =  np.mean(all_actual_fb_all, 0)\n",
    "all_actual_fb_upper = np.max(all_actual_fb_all, 0) - all_actual_fb_mean\n",
    "all_actual_fb_lower = all_actual_fb_mean - np.min(all_actual_fb_all, 0)\n",
    "all_actual_fb_errs = np.stack((all_actual_fb_lower, all_actual_fb_upper))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "for k in range(len(beta_values)):\n",
    "    print('beta', beta_values[k])\n",
    "    fig, ax = plt.subplots(figsize=(15, 12))\n",
    "\n",
    "    plot = ax.errorbar(n_values, y=all_qc_fb_mean[k,:,0] , yerr=all_qc_fb_errs[:, k, :, 0], label='Bound, m=1',\n",
    "                      capsize=10)\n",
    "    plot = ax.errorbar(n_values, y=all_qc_fb_mean[k,:,1] , yerr=all_qc_fb_errs[:, k, :, 1], label='Bound, m=2',\n",
    "                      capsize=10)\n",
    "    plot = ax.errorbar(n_values, y=all_qc_fb_mean[k,:,2] , yerr=all_qc_fb_errs[:, k, :, 2], label='Bound, m=5',\n",
    "                      capsize=10)\n",
    "    \n",
    "    plot = ax.errorbar(n_values, y=all_actual_fb_mean[k,:,0] , yerr=all_actual_fb_errs[:, k, :, 0], label='Actual, m=1',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "    plot = ax.errorbar(n_values, y=all_actual_fb_mean[k,:,1] , yerr=all_actual_fb_errs[:, k, :, 1], label='Actual, m=2',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "    plot = ax.errorbar(n_values, y=all_actual_fb_mean[k,:,2] , yerr=all_actual_fb_errs[:, k, :, 2], label='Actual, m=5',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "\n",
    "    plt.xticks(all_n_fb[k])\n",
    "\n",
    "    ax.set_xlabel('n')\n",
    "    ax.set_ylabel('Relative Optimality Gap')\n",
    "    ax.set_yscale('log')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "for k in range(len(n_values)):\n",
    "    print('n', n_values[k])\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "    \n",
    "    plot = ax.errorbar(beta_values, y=all_qc_fb_mean[:, k,0] , yerr=all_qc_fb_errs[:, :, k, 0], label='Bound, m=1',\n",
    "                      capsize=10)\n",
    "    plot = ax.errorbar(beta_values, y=all_qc_fb_mean[:, k,1] , yerr=all_qc_fb_errs[:, :, k, 1], label='Bound, m=2',\n",
    "                      capsize=10)\n",
    "    plot = ax.errorbar(beta_values, y=all_qc_fb_mean[:, k,2] , yerr=all_qc_fb_errs[:, :, k, 2], label='Bound, m=5',\n",
    "                      capsize=10)\n",
    "    \n",
    "    plot = ax.errorbar(beta_values, y=all_actual_fb_mean[:, k,0] , yerr=all_actual_fb_errs[:, :, k, 0], label='Actual, m=1',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "    plot = ax.errorbar(beta_values, y=all_actual_fb_mean[:, k,1] , yerr=all_actual_fb_errs[:, :, k, 1], label='Actual, m=2',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "    plot = ax.errorbar(beta_values, y=all_actual_fb_mean[:, k,2] , yerr=all_actual_fb_errs[:, :, k, 2], label='Actual, m=5',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "\n",
    "    plt.xticks(beta_values)\n",
    "\n",
    "    ax.set_xlabel('beta')\n",
    "    ax.set_ylabel('Relative Optimality Gap')\n",
    "    ax.set_yscale('log')\n",
    "    ax.set_xscale('log')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear Activation Case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "linear_results_all_0, linear_params_all_0 = compute_qualifications_over_n_beta(linear=True)\n",
    "linear_results_all_02, linear_params_all_02 = compute_qualifications_over_n_beta(linear=True)\n",
    "linear_results_all_03, linear_params_all_03 = compute_qualifications_over_n_beta(linear=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "linear_all_qc_fb = []\n",
    "linear_all_fl_fb = []\n",
    "linear_all_cvx_fb = []\n",
    "linear_all_n_fb = []\n",
    "\n",
    "for beta in beta_values:\n",
    "    qualification_constraints_fixed_beta = []\n",
    "    final_losses_fixed_beta = []\n",
    "    cvx_losses_fixed_beta = []\n",
    "    n_values_fixed_beta = []\n",
    "    for i, result in enumerate(linear_results_all_0):\n",
    "        params = linear_params_all_0[i]\n",
    "        if params[1] != beta:\n",
    "            continue\n",
    "        \n",
    "        n_values_fixed_beta.append(params[0])\n",
    "        cvx_losses_fixed_beta.append(result[0])\n",
    "        final_losses_fixed_beta.append([result[1][j][-1] for j in range(len(result[1]))])\n",
    "        qualification_constraints_fixed_beta.append(result[2])\n",
    "        \n",
    "    linear_all_qc_fb.append(qualification_constraints_fixed_beta)\n",
    "    linear_all_fl_fb.append(final_losses_fixed_beta)\n",
    "    linear_all_cvx_fb.append(cvx_losses_fixed_beta)\n",
    "    linear_all_n_fb.append(n_values_fixed_beta)\n",
    "    \n",
    "linear_all_qc_fb = np.array(linear_all_qc_fb)\n",
    "linear_all_fl_fb = np.array(linear_all_fl_fb)\n",
    "linear_all_cvx_fb = np.array(linear_all_cvx_fb)\n",
    "linear_all_n_fb = np.array(linear_all_n_fb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "linear_all_qc_fb2 = []\n",
    "linear_all_fl_fb2 = []\n",
    "linear_all_cvx_fb2 = []\n",
    "linear_all_n_fb2 = []\n",
    "\n",
    "for beta in beta_values:\n",
    "    qualification_constraints_fixed_beta = []\n",
    "    final_losses_fixed_beta = []\n",
    "    cvx_losses_fixed_beta = []\n",
    "    n_values_fixed_beta = []\n",
    "    for i, result in enumerate(linear_results_all_02):\n",
    "        params = linear_params_all_02[i]\n",
    "        if params[1] != beta:\n",
    "            continue\n",
    "        \n",
    "        n_values_fixed_beta.append(params[0])\n",
    "        cvx_losses_fixed_beta.append(result[0])\n",
    "        final_losses_fixed_beta.append([result[1][j][-1] for j in range(len(result[1]))])\n",
    "        qualification_constraints_fixed_beta.append(result[2])\n",
    "        \n",
    "    linear_all_qc_fb2.append(qualification_constraints_fixed_beta)\n",
    "    linear_all_fl_fb2.append(final_losses_fixed_beta)\n",
    "    linear_all_cvx_fb2.append(cvx_losses_fixed_beta)\n",
    "    linear_all_n_fb2.append(n_values_fixed_beta)\n",
    "    \n",
    "linear_all_qc_fb2 = np.array(linear_all_qc_fb2)\n",
    "linear_all_fl_fb2 = np.array(linear_all_fl_fb2)\n",
    "linear_all_cvx_fb2 = np.array(linear_all_cvx_fb2)\n",
    "linear_all_n_fb2 = np.array(linear_all_n_fb2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "linear_all_qc_fb3 = []\n",
    "linear_all_fl_fb3 = []\n",
    "linear_all_cvx_fb3 = []\n",
    "linear_all_n_fb3 = []\n",
    "\n",
    "for beta in beta_values:\n",
    "    qualification_constraints_fixed_beta = []\n",
    "    final_losses_fixed_beta = []\n",
    "    cvx_losses_fixed_beta = []\n",
    "    n_values_fixed_beta = []\n",
    "    for i, result in enumerate(linear_results_all_03):\n",
    "        params = linear_params_all_03[i]\n",
    "        if params[1] != beta:\n",
    "            continue\n",
    "        \n",
    "        n_values_fixed_beta.append(params[0])\n",
    "        cvx_losses_fixed_beta.append(result[0])\n",
    "        final_losses_fixed_beta.append([result[1][j][-1] for j in range(len(result[1]))])\n",
    "        qualification_constraints_fixed_beta.append(result[2])\n",
    "        \n",
    "    linear_all_qc_fb3.append(qualification_constraints_fixed_beta)\n",
    "    linear_all_fl_fb3.append(final_losses_fixed_beta)\n",
    "    linear_all_cvx_fb3.append(cvx_losses_fixed_beta)\n",
    "    linear_all_n_fb3.append(n_values_fixed_beta)\n",
    "    \n",
    "linear_all_qc_fb3 = np.array(linear_all_qc_fb3)\n",
    "linear_all_fl_fb3 = np.array(linear_all_fl_fb3)\n",
    "linear_all_cvx_fb3 = np.array(linear_all_cvx_fb3)\n",
    "linear_all_n_fb3 = np.array(linear_all_n_fb3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "linear_all_qc_fb_all = np.maximum(np.stack((linear_all_qc_fb, linear_all_qc_fb2, linear_all_qc_fb3)), 1e-14)\n",
    "linear_all_fl_fb_all = np.stack((linear_all_fl_fb, linear_all_fl_fb2, linear_all_fl_fb3))\n",
    "linear_all_cvx_fb_all = np.stack((linear_all_cvx_fb, linear_all_cvx_fb2, linear_all_cvx_fb3))\n",
    "linear_all_n_fb_all = np.stack((linear_all_n_fb, linear_all_n_fb2, linear_all_n_fb3))\n",
    "\n",
    "linear_all_qc_fb_mean = np.mean(linear_all_qc_fb_all, 0)\n",
    "linear_all_qc_fb_upper = np.max(linear_all_qc_fb_all, 0) - linear_all_qc_fb_mean\n",
    "linear_all_qc_fb_lower = linear_all_qc_fb_mean - np.min(linear_all_qc_fb_all, 0)\n",
    "linear_all_qc_fb_errs = np.stack((linear_all_qc_fb_lower, linear_all_qc_fb_upper))\n",
    "\n",
    "linear_all_actual_fb_all = np.maximum((linear_all_fl_fb_all - np.expand_dims(linear_all_cvx_fb_all, 3))/np.expand_dims(linear_all_cvx_fb_all, 3), 1e-14)\n",
    "\n",
    "linear_all_actual_fb_mean =  np.mean(linear_all_actual_fb_all, 0)\n",
    "linear_all_actual_fb_upper = np.max(linear_all_actual_fb_all, 0) - linear_all_actual_fb_mean\n",
    "linear_all_actual_fb_lower = linear_all_actual_fb_mean - np.min(linear_all_actual_fb_all, 0)\n",
    "linear_all_actual_fb_errs = np.stack((linear_all_actual_fb_lower, linear_all_actual_fb_upper))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "for k in range(len(n_values)):\n",
    "    print('n', n_values[k])\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "\n",
    "    plot = ax.errorbar(beta_values, y=linear_all_qc_fb_mean[:, k,0] , yerr=linear_all_qc_fb_errs[:, :, k, 0], label='Bound, m=1',\n",
    "                      capsize=10)\n",
    "    plot = ax.errorbar(beta_values, y=linear_all_qc_fb_mean[:, k,1] , yerr=linear_all_qc_fb_errs[:, :, k, 1], label='Bound, m=2',\n",
    "                      capsize=10)\n",
    "    plot = ax.errorbar(beta_values, y=linear_all_qc_fb_mean[:, k,2] , yerr=linear_all_qc_fb_errs[:, :, k, 2], label='Bound, m=5',\n",
    "                      capsize=10)\n",
    "    \n",
    "    plot = ax.errorbar(beta_values, y=linear_all_actual_fb_mean[:, k,0] , yerr=linear_all_actual_fb_errs[:, :, k, 0], label='Actual, m=1',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "    plot = ax.errorbar(beta_values, y=linear_all_actual_fb_mean[:, k,1] , yerr=linear_all_actual_fb_errs[:, :, k, 1], label='Actual, m=2',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "    plot = ax.errorbar(beta_values, y=linear_all_actual_fb_mean[:, k,2] , yerr=linear_all_actual_fb_errs[:, :, k, 2], label='Actual, m=5',\n",
    "                      capsize=10, linestyle='dashed')\n",
    "\n",
    "\n",
    "    plt.xticks(beta_values)\n",
    "\n",
    "    ax.set_xlabel('beta')\n",
    "    ax.set_ylabel('Relative Optimality Gap')\n",
    "    ax.set_yscale('log')\n",
    "    ax.set_xscale('log')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
