{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render matplotlib figures inline, and reload edited modules automatically\n",
    "# before each cell execution (autoreload mode 2).\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pdb\n",
    "import scipy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import time\n",
    "import cvxpy as cp\n",
    "import pickle\n",
    "import torch.functional as F\n",
    "\n",
    "from lib.dataset_utils import *\n",
    "from lib.solvers import solve_feas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Enumerate GPUs in PCI bus order and expose only device \"1\" to this process.\n",
    "os.environ.update(\n",
    "    CUDA_DEVICE_ORDER=\"PCI_BUS_ID\",\n",
    "    CUDA_VISIBLE_DEVICES=\"1\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed both the NumPy and PyTorch global RNGs before loading/shuffling data.\n",
    "np.random.seed(1)\n",
    "torch.manual_seed(1)\n",
    "\n",
    "# Load the train / validation / test splits with the project helper\n",
    "# (presumably provided by the star import of lib.dataset_utils - verify).\n",
    "# NOTE(review): data_dir is a hard-coded absolute path; adjust per machine.\n",
    "(x_train, y_train), (x_valid, y_valid), (x_test, y_test) = load_mnist_all(\n",
    "    data_dir='/data', val_size=0.1, shuffle=True, seed=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration dictionary handed to get_polytope / proj_polytope below.\n",
    "params = dict(\n",
    "    dim=784,\n",
    "    device='cpu',\n",
    "    max_proj_iters=1000,\n",
    "    tol=1e-5,\n",
    "    tol_abs=1e-5,\n",
    "    tol_rel=1e-3,\n",
    "    early_stop=True,\n",
    "    rho=1,\n",
    "    clip=None,  # alternative previously tried: (0, 1)\n",
    "    dtype=torch.float32,\n",
    "    step_size=5e-4,\n",
    "    num_partitions=3,\n",
    "    upperbound=np.inf,\n",
    "    check_obj_steps=100000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_polytope(points, idx_intr_point, params):\n",
    "    \"\"\"\n",
    "    Compute polytope Ax <= b of the Voronoi cell of\n",
    "    points[idx_intr_point].\n",
    "\n",
    "    Every other point contributes one half-space: its unit normal points\n",
    "    from points[idx_intr_point] towards that point, and its offset puts the\n",
    "    boundary at the midpoint between the two points.\n",
    "\n",
    "    Args:\n",
    "        points: 2-D tensor with one point per row.\n",
    "        idx_intr_point: row index of the point whose cell is computed\n",
    "            (negative indices count from the end).\n",
    "        params: dict; 'dtype' is the dtype of A and b, 'device' is the\n",
    "            device they are moved to before being returned.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (A, b): A has shape (points.size(0) - 1, points.size(1)),\n",
    "        b has shape (points.size(0) - 1,).\n",
    "    \"\"\"\n",
    "    # Use torch.nn.functional explicitly: the notebook-level alias `F` is\n",
    "    # imported from torch.functional, which does not provide `normalize`.\n",
    "    normalize = torch.nn.functional.normalize\n",
    "    # Map a negative index to its positive equivalent so that the slicing\n",
    "    # that removes the interior point's own row stays correct.\n",
    "    if idx_intr_point < 0:\n",
    "        idx_intr_point += points.size(0)\n",
    "    center = points[idx_intr_point]\n",
    "\n",
    "    A = torch.zeros((points.size(0), points.size(1)),\n",
    "                    dtype=params['dtype'])\n",
    "    b = torch.zeros(A.size(0), dtype=params['dtype'])\n",
    "    # for some reason, processing in a loop rather than matrix form\n",
    "    # gives more precise solutions.\n",
    "    for i in range(A.size(0)):\n",
    "        A[i] = normalize(points[i] - center, 2, 0)\n",
    "        b[i] = (A[i] @ (points[i] + center)) / 2\n",
    "    # The row built for the interior point itself is all zeros; drop it.\n",
    "    A = torch.cat([A[:idx_intr_point], A[idx_intr_point + 1:]], dim=0)\n",
    "    b = torch.cat([b[:idx_intr_point], b[idx_intr_point + 1:]], dim=0)\n",
    "    return A.to(params['device']), b.to(params['device'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so the random x_hat drawn in this cell is repeatable.\n",
    "np.random.seed(1)\n",
    "torch.manual_seed(1)\n",
    "\n",
    "# Disabled alternative: random unit-norm constraint rows with b = 1.\n",
    "# num_constraints = 10000\n",
    "# A = torch.zeros((num_constraints, 784), dtype=params['dtype']).normal_()\n",
    "# A = F.normalize(A, 2, 1)\n",
    "# A = A.to(params['device'])\n",
    "# b = torch.ones(num_constraints, device=params['device'], dtype=params['dtype'])\n",
    "# x_hat = torch.ones(784, device=params['device'], dtype=params['dtype'])\n",
    "\n",
    "# Disabled alternative: box constraints 0 <= x <= 1 with x_hat = 2.\n",
    "# box = torch.eye(784, device=params['device'], dtype=params['dtype'])\n",
    "# A = torch.cat([box, - box], dim=0)\n",
    "# x_hat = torch.ones(784, device=params['device'], dtype=params['dtype']) + 1\n",
    "# b = torch.cat([torch.ones_like(x_hat), torch.zeros_like(x_hat)], dim=0)\n",
    "\n",
    "# Active problem: Voronoi cell of the first (flattened) training point.\n",
    "A, b = get_polytope(x_train.view(-1, params['dim']), 0, params)\n",
    "# Point to be projected: allocated as ones, then overwritten in place with\n",
    "# uniform random values.\n",
    "x_hat = torch.ones(params['dim'], \n",
    "                   device=params['device'], \n",
    "                   dtype=params['dtype'])\n",
    "x_hat.uniform_()\n",
    "# x_hat = x_train[1].view(params['dim'])\n",
    "# Sanity check: the generating point must satisfy every constraint of its own cell.\n",
    "assert np.all((A @ x_train.view(-1, params['dim'])[0].to(params['device']) \n",
    "               <= b).cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# box = torch.eye(784, device=params['device'], dtype=params['dtype'])\n",
    "# A = torch.cat([A, box, - box], dim=0)\n",
    "# b = torch.cat([b, torch.ones_like(x_hat), torch.zeros_like(x_hat)], dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gram matrix of the constraint rows, and the signed residual of x_hat for\n",
    "# every constraint (positive entries are violated constraints).\n",
    "# NOTE(review): AAT has one row and column per constraint and may be very\n",
    "# large; the active proj_polytope call further down passes None in its place.\n",
    "AAT = A @ A.T\n",
    "b_hat = A @ x_hat - b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(42651, device='cuda:0')"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of constraints that x_hat violates before projection.\n",
    "((A @ x_hat - b) > 0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.232408761978149\n"
     ]
    }
   ],
   "source": [
    "# Select the projection method (keep exactly one assignment active).\n",
    "# params['method'] = 'dual_ascent'\n",
    "# params['method'] = 'gca'\n",
    "# params['method'] = 'parallel_gca'\n",
    "params['method'] = 'dykstra'\n",
    "# params['method'] = 'cvx'\n",
    "\n",
    "# NOTE(review): star import in the middle of the notebook; proj_polytope is\n",
    "# presumably defined in lib.projection - confirm.\n",
    "from lib.projection import *\n",
    "# Wall-clock timing of a single proj_polytope call; the elapsed seconds are printed.\n",
    "start = time.time()\n",
    "# x = proj_polytope(x_hat, A, AAT, b, params)\n",
    "x = proj_polytope(x_hat, A, None, b, params)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2.3281e-02, 7.3590e-03, 6.5850e-03, 9.6170e-02, 1.5340e-01, 2.6087e-01,\n",
      "        1.8810e-02, 5.1887e-02, 9.9162e-02, 1.6780e-01, 4.5843e-02, 4.9195e-02,\n",
      "        3.0815e-02, 1.1751e-02, 2.5996e-03, 1.0693e-03, 3.6686e-02, 1.0081e-01,\n",
      "        5.0745e-02, 1.2917e-01, 6.1694e-02, 8.2414e-03, 3.8188e-02, 2.1438e-02,\n",
      "        1.6334e-01, 5.3193e-02, 1.0666e-01, 1.9802e-02, 2.2411e-01, 8.5125e-03,\n",
      "        1.2500e-02, 7.9642e-02, 1.2558e-01, 1.8706e-02, 3.4617e-02, 4.2405e-02,\n",
      "        6.9020e-02, 3.2228e-02, 5.9142e-03, 1.7935e-02, 3.1402e-03, 5.5309e-02,\n",
      "        3.5913e-02, 3.4989e-02, 8.9275e-02, 1.2681e-01, 2.6863e-01, 1.6228e-02,\n",
      "        8.0599e-02, 9.6752e-02, 5.4533e-02, 1.2133e-01, 1.9626e-01, 4.7637e-02,\n",
      "        4.1484e-02, 6.5536e-02, 3.9083e-02, 4.3780e-02, 1.1664e-01, 1.0410e-02,\n",
      "        4.4007e-02, 6.0897e-03, 6.7472e-03, 2.0986e-02, 2.3079e-02, 3.3366e-01,\n",
      "        4.9620e-02, 3.0160e-03, 3.5011e-01, 3.3004e-03, 6.1066e-02, 5.4264e-02,\n",
      "        1.3067e-02, 1.5809e-01, 9.1283e-02, 9.5057e-02, 7.2694e-02, 5.1973e-02,\n",
      "        1.5539e-01, 8.8408e-02, 3.5765e-02, 1.4650e-02, 3.3511e-02, 1.3510e-01,\n",
      "        8.5631e-02, 1.2383e-01, 6.2618e-02, 1.3485e-01, 5.8535e-03, 1.2380e-02,\n",
      "        3.1158e-02, 8.5663e-02, 3.4150e-01, 4.4699e-02, 1.3703e-01, 7.8502e-02,\n",
      "        4.9812e-02, 8.4143e-02, 1.2638e-01, 2.4052e-02, 1.5513e-01, 7.8112e-02,\n",
      "        2.7319e-02, 2.9675e-01, 7.3828e-02, 5.6210e-02, 1.7821e-01, 5.3564e-03,\n",
      "        2.7745e-02, 1.6265e-01, 6.3763e-02, 4.7609e-03, 1.1787e-01, 6.7399e-03,\n",
      "        6.0694e-02, 3.1771e-02, 1.3999e-01, 2.7742e-03, 2.1642e-02, 3.2068e-02,\n",
      "        4.4107e-02, 7.7227e-02, 2.7432e-02, 1.8235e-01, 2.4696e-01, 2.2043e-02,\n",
      "        2.7461e-02, 9.4220e-02, 2.9408e-01, 2.6855e-01, 3.1580e-02, 7.2278e-02,\n",
      "        1.9428e-01, 5.8600e-02, 4.8406e-02, 3.7668e-02, 9.7543e-02, 5.1237e-02,\n",
      "        5.5970e-02, 6.6257e-02, 3.0630e-03, 3.7697e-02, 5.2653e-02, 5.9077e-02,\n",
      "        2.4060e-02, 2.6779e-02, 3.5877e-02, 1.3295e-01, 2.2706e-03, 2.4666e-04,\n",
      "        7.7628e-02, 8.1310e-02, 2.6496e-01, 1.4724e-02, 2.2173e-02, 1.4293e-01,\n",
      "        2.1297e-01, 1.0025e-01, 2.5067e-01, 5.3141e-02, 1.4426e-01, 8.9486e-02,\n",
      "        3.6558e-02, 5.8918e-02, 1.9124e-02, 1.0884e-01, 3.0108e-02, 2.1591e-01,\n",
      "        2.9712e-02, 5.8306e-02, 7.2462e-02, 7.7182e-02, 3.8598e-02, 2.4847e-02,\n",
      "        7.3276e-03, 5.0430e-03, 3.2427e-03, 7.5269e-04, 1.3648e-01, 1.3520e-01,\n",
      "        1.4249e-01, 1.2085e-01, 1.5577e-02, 5.9155e-02, 2.7650e-02, 1.0982e-01,\n",
      "        9.5927e-03, 4.6684e-02, 7.7093e-02, 1.3230e-01, 7.0427e-02, 1.4821e-01,\n",
      "        7.1853e-02, 2.0523e-03, 1.7245e-01, 2.0707e-01, 4.5325e-02, 1.3197e-01,\n",
      "        1.1772e-01, 3.4413e-01, 3.8394e-03, 4.7244e-02, 1.1616e-02, 1.1438e-02,\n",
      "        1.0730e-01, 5.5940e-02, 8.7126e-03, 1.7049e-02, 4.3470e-01, 1.9476e-02,\n",
      "        5.4249e-02, 1.1437e-01, 9.6485e-02, 1.1481e-01, 1.6468e-01, 1.8959e-01,\n",
      "        6.3286e-02, 8.6783e-02, 4.8687e-04, 8.1047e-02, 1.3447e-04, 1.9430e-02,\n",
      "        2.1490e-01, 2.8318e-02, 2.8697e-02, 2.4377e-01, 3.3919e-02, 1.1021e-01,\n",
      "        1.9872e-02, 2.4694e-01, 1.0253e-02, 3.2906e-02, 5.7430e-03, 3.2499e-02,\n",
      "        1.3150e-01, 5.9276e-02, 3.1269e-02, 8.3426e-02, 3.6977e-02, 4.7795e-02,\n",
      "        1.1053e-02, 9.7877e-02, 9.1285e-02, 7.4106e-02, 9.5390e-03, 8.3304e-02,\n",
      "        2.4350e-03, 9.1568e-02, 2.2632e-02, 8.2056e-02, 6.5059e-03, 5.4264e-02,\n",
      "        1.5525e-02, 2.6292e-03, 5.9862e-02, 3.6520e-02, 4.8346e-02, 4.2765e-03,\n",
      "        3.1204e-02, 5.1130e-02, 2.2423e-03, 1.8248e-01, 9.0884e-02, 3.7920e-02,\n",
      "        8.6193e-02, 2.6896e-02, 7.8675e-02, 9.6972e-03, 1.8966e-02, 1.8173e-02,\n",
      "        2.0656e-01, 1.3555e-01, 6.4132e-02, 5.6224e-02, 2.5009e-01, 1.8994e-01,\n",
      "        2.0011e-01, 1.2108e-01, 6.0219e-02, 3.0882e-02, 9.5112e-03, 9.5705e-02,\n",
      "        7.7082e-02, 3.9115e-02, 6.4629e-02, 3.5267e-02, 5.4377e-02, 1.4222e-01,\n",
      "        1.0230e-02, 5.0721e-02, 4.4520e-03, 4.4561e-02, 2.3993e-02, 2.5398e-01,\n",
      "        4.2922e-03, 4.8909e-02, 2.3104e-01, 1.2139e-02, 8.1217e-03, 4.4853e-02,\n",
      "        7.2214e-02, 4.4817e-03, 1.6013e-01, 2.0770e-01, 5.7420e-02, 1.1254e-02,\n",
      "        1.4190e-02, 5.5429e-02, 9.3120e-02, 6.7459e-02, 3.4002e-01, 9.0275e-02,\n",
      "        6.9785e-02, 2.2442e-02, 2.6567e-02, 5.6986e-02, 1.8700e-02, 1.6890e-02,\n",
      "        2.0831e-01, 2.3919e-01, 1.4443e-01, 5.0539e-02, 9.3502e-02, 2.1952e-01,\n",
      "        1.5335e-02, 8.9219e-02, 7.6424e-03, 9.4152e-02, 1.3357e-01, 5.0219e-02,\n",
      "        3.1610e-02, 2.5588e-02, 1.5009e-02, 4.9494e-02, 4.1312e-02, 4.9098e-02,\n",
      "        5.4086e-02, 1.2580e-01, 1.3150e-02, 1.3770e-01, 3.2850e-02, 4.5287e-02,\n",
      "        6.2196e-02, 2.9936e-03, 7.6076e-02, 5.6378e-02, 1.2354e-02, 2.4643e-02,\n",
      "        1.0607e-02, 3.5292e-01, 6.8783e-02, 1.2397e-01, 2.5270e-02, 1.0539e-01,\n",
      "        1.6482e-01, 8.2890e-02, 7.5935e-02, 3.8497e-02, 5.7991e-02, 4.4370e-02,\n",
      "        4.6770e-01, 1.0239e-01, 1.5488e-01, 1.2685e-02, 6.8073e-02, 2.0391e-02,\n",
      "        2.5109e-01, 2.3426e-01, 1.3944e-01, 1.4329e-01, 6.6641e-02, 1.0782e-02,\n",
      "        8.4214e-02, 3.7308e-03, 2.6584e-01, 7.7184e-02, 1.4849e-01, 1.4112e-01,\n",
      "        1.0433e-03, 1.3126e-01, 5.7960e-04, 1.1563e-02, 3.0827e-02, 3.3990e-03,\n",
      "        2.3424e-02, 1.4469e-01, 5.5744e-02, 6.5648e-02, 1.3205e-01, 1.1032e-01,\n",
      "        7.4550e-02, 1.8420e-01, 2.0309e-01, 1.3082e-01, 1.9804e-01, 8.6466e-02,\n",
      "        1.1066e-01, 1.0323e-01, 6.5542e-02, 1.5591e-02, 4.3997e-02, 2.0905e-02,\n",
      "        7.6557e-02, 2.0527e-01, 2.3679e-01, 6.0502e-02, 4.7188e-03, 5.6650e-02,\n",
      "        8.9524e-02, 1.1759e-03, 9.7394e-05, 2.5037e-02, 1.9382e-02, 8.8399e-02,\n",
      "        4.6991e-02, 6.5810e-02, 1.3220e-01, 2.9584e-02, 5.5080e-02, 2.2215e-01,\n",
      "        1.9989e-01, 1.0422e-01, 1.7230e-01, 7.3856e-02, 4.4482e-02, 1.0938e-01,\n",
      "        4.2759e-02, 7.5381e-02, 1.6647e-02, 4.0389e-03, 5.9573e-02, 6.5940e-02,\n",
      "        8.0132e-02, 7.6921e-02, 7.0405e-02, 9.6821e-02, 4.7778e-02, 1.5207e-01,\n",
      "        2.4551e-01, 3.2279e-02, 7.8125e-02, 3.7421e-02, 7.8675e-02, 8.2766e-02,\n",
      "        1.3193e-01, 4.3005e-02, 7.7722e-02, 1.6318e-01, 5.2284e-02, 1.9655e-01,\n",
      "        1.5465e-01, 7.5050e-02, 1.9931e-02, 2.8555e-02, 8.4184e-02, 4.0128e-02,\n",
      "        2.4034e-02, 4.6150e-02, 2.3282e-01, 3.7384e-02, 4.8137e-02, 1.6982e-02,\n",
      "        8.7888e-02, 3.0277e-03, 1.1948e-01, 8.0276e-02, 3.1070e-02, 3.5972e-02,\n",
      "        5.7663e-01, 5.6591e-02, 2.4702e-01, 3.3074e-02, 1.8679e-01, 3.0346e-01,\n",
      "        2.7704e-01, 1.1577e-01, 3.3140e-02, 9.4256e-02, 5.1174e-02, 1.3042e-01,\n",
      "        1.7049e-01, 1.5300e-01, 8.3712e-02, 5.4957e-02, 4.0054e-05, 3.5010e-02,\n",
      "        1.4897e-01, 1.1096e-02, 6.9928e-03, 2.0591e-01, 7.9504e-03, 1.5144e-01,\n",
      "        7.2165e-02, 4.7413e-02, 1.7024e-01, 1.5497e-01, 2.4760e-03, 4.1796e-02,\n",
      "        1.5931e-01, 7.4070e-02, 6.3966e-02, 2.5508e-01, 9.2344e-03, 1.3130e-01,\n",
      "        4.5714e-02, 5.1266e-04, 3.1427e-02], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Constraint residuals at the returned point x; print only the positive\n",
    "# entries, i.e. the constraints that x still violates.\n",
    "res = A @ x - b\n",
    "print(res[res > 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.6708, device='cuda:0')"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Squared Euclidean distance between the returned point x and x_hat.\n",
    "((x - x_hat) ** 2).sum()\n",
    "# ((x - x_hat.numpy()) ** 2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.259012460708618\n",
      "True\n",
      "8.473605871200562\n"
     ]
    }
   ],
   "source": [
    "# Time solve_feas (from lib.solvers) on the stacked matrix [A | b], converted to NumPy.\n",
    "# NOTE(review): the meaning of the second argument (2) is not visible here - confirm.\n",
    "start = time.time()\n",
    "out = solve_feas(torch.cat([A, b.unsqueeze(1)], dim=1).cpu().numpy(), 2)\n",
    "print(out)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Runtime on full MNIST (x_hat: uniform random)\n",
    "- dual_ascent\n",
    "  - cpu: way too long (step_size 1e-1, tol 1e-5)\n",
    "  - cuda: s (step_size 1e-1, tol 1e-5)\n",
    "- gca\n",
    "  - cpu: 0.5-1s (tol 1e-5)\n",
    "  - cuda: s (tol 1e-5)\n",
    "  - very fast when the solution is simply the direct projection onto a single facet.\n",
    "- parallel_gca\n",
    "  - cpu: s (num_partitions 10, tol 1e-5).\n",
    "  - cuda: s (tol 1e-5).\n",
    "- admm\n",
    "  - cpu: \n",
    "  - cuda:\n",
    "- cvx (MOSEK):  \n",
    "- cvx (default): \n",
    "\n",
    "- Solving Farkas takes too long."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
