{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial import Voronoi, voronoi_plot_2d\n",
    "from scipy.spatial import ConvexHull\n",
    "import scipy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import time\n",
    "import cvxpy as cp\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Testing runtime scaling with the number of points and dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dim 4: 0.0506s\n",
      "Dim 5: 0.1401s\n",
      "Dim 6: 0.8386s\n",
      "Dim 7: 5.9499s\n"
     ]
    }
   ],
   "source": [
    "# Time Voronoi construction for a fixed number of points as the dimension grows.\n",
    "# perf_counter is a monotonic, high-resolution clock meant for benchmarking,\n",
    "# unlike the adjustable wall clock returned by time.time().\n",
    "num_points = 200\n",
    "for dim in [4, 5, 6, 7]:\n",
    "    points = np.random.uniform(size=(num_points, dim))\n",
    "    start_time = time.perf_counter()\n",
    "    vor = Voronoi(points)\n",
    "    print('Dim %d: %.4fs' % (dim, time.perf_counter() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num points 100: 0.2524s\n",
      "Num points 200: 0.8675s\n",
      "Num points 300: 1.8501s\n"
     ]
    }
   ],
   "source": [
    "# Time Voronoi construction in a fixed dimension as the number of points grows,\n",
    "# using the monotonic high-resolution perf_counter clock.\n",
    "dim = 6\n",
    "for num_points in [100, 200, 300]:\n",
    "    points = np.random.uniform(size=(num_points, dim))\n",
    "    start_time = time.perf_counter()\n",
    "    vor = Voronoi(points)\n",
    "    print('Num points %d: %.4fs' % (num_points, time.perf_counter() - start_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--- "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 2019\n",
    "np.random.seed(seed)\n",
    "\n",
    "num_points = 100\n",
    "dim = 2\n",
    "\n",
    "# sample points from unifromly random distribution inside [0, 1] box\n",
    "points = np.random.uniform(size=(num_points, dim))\n",
    "\n",
    "# construct Voronoi diagram\n",
    "vor = Voronoi(points)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dealing with infinite segments\n",
    "\n",
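    "`vor.ridge_points` gives, for each ridge, the pair of input points whose separating hyperplane the ridge lies on.  \n",
    "`vor.ridge_vertices` gives the indices of the Voronoi vertices forming each ridge; an index of `-1` marks a vertex at infinity, i.e. an infinite ridge.\n",
    "\n",
    "TODO: infinite ridges are currently skipped."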
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VorLSH(object):\n",
    "    def __init__(self, vor, params, seed=123):\n",
    "        # set random seed for this LSH\n",
    "        np.random.seed(seed)\n",
    "        self.vor = vor\n",
    "        self.dim = params['dim']\n",
    "        self.num_proj = params['num_proj']\n",
    "        self.bucket_size = params['bucket_size']\n",
    "        self.max_bucket = 2 * np.ceil(\n",
    "            np.sqrt(self.dim) / self.bucket_size).astype(np.int32) + 1\n",
    "        # bucket goes from [- self.offset, self.offset + 1]\n",
    "        self.offset = self.max_bucket // 2\n",
    "\n",
    "        # init random projection lines passing through (0, 0)\n",
    "        w = np.random.randn(self.num_proj, self.dim)\n",
    "        self.w = w / np.linalg.norm(w, axis=1, keepdims=True)\n",
    "\n",
    "        # naive implementation\n",
    "        self.num_edges = len(vor.ridge_vertices)\n",
    "        self.tab = np.zeros((self.num_proj, self.max_bucket, self.num_edges))\n",
    "        for i, vertices in enumerate(vor.ridge_vertices):\n",
    "            # TODO: we skip infinite ones for now\n",
    "            if -1 in vertices:\n",
    "                continue\n",
    "            # get buckets for all vertices and projections\n",
    "            # b has shape (num_proj, len(vertices))\n",
    "            b = self.w @ vor.vertices[vertices].T\n",
    "            b = (b // self.bucket_size).astype(np.int32)\n",
    "            # b should be clipped to [- offset, offset]\n",
    "            b = np.clip(b, - self.offset, self.offset)\n",
    "            for j in range(self.num_proj):\n",
    "                start = b[j].min() + self.offset\n",
    "                stop = b[j].max() + self.offset\n",
    "                self.tab[j, start:stop + 1, i] = 1\n",
    "\n",
    "    def _check_bound(self, b):\n",
    "        return (b >= - self.offset) * (b <= self.offset)\n",
    "    \n",
    "    def ridge_to_vertices(self, indices):\n",
    "        \"\"\"\n",
    "        Take indices of ridges and return the corresponding\n",
    "        sets of vertices that form each ridge\n",
    "        \"\"\"\n",
    "        vertices = []\n",
    "        for idx in indices:\n",
    "            v = self.vor.ridge_vertices[idx]\n",
    "            if -1 in v:\n",
    "                raise ValueError('Something went wrong!')\n",
    "            vertices.append(self.vor.vertices[v])\n",
    "        return vertices\n",
    "\n",
    "    def query(self, x):\n",
    "        out = np.zeros((self.num_proj, self.num_edges))\n",
    "        buckets = ((self.w @ x) // self.bucket_size).astype(np.int32)\n",
    "        if not np.all(self._check_bound(buckets)):\n",
    "            raise ValueError('Bucket is out of bound.')\n",
    "        buckets += self.offset\n",
    "        for i in range(self.num_proj):\n",
    "            out[i] = self.tab[i, buckets[i]]\n",
    "        return out\n",
    "    \n",
    "    def query_nearby(self, x, b):\n",
    "        # count nearby buckets as a hit as well [-b, b]\n",
    "        out = np.zeros((self.num_proj, self.num_edges))\n",
    "        buckets = ((self.w @ x) // self.bucket_size).astype(np.int32)\n",
    "        if not np.all(self._check_bound(buckets)):\n",
    "            raise ValueError('Bucket is out of bound.')\n",
    "        buckets += self.offset\n",
    "        for i in range(self.num_proj):\n",
    "            start = max(buckets[i] - b, 0)\n",
    "            end = buckets[i] + b + 1\n",
    "            out[i] = self.tab[i, start : end].sum(0)\n",
    "        out = (out >= 1).astype(np.int32)\n",
    "        return out\n",
    "\n",
    "    def predict(self, x, b=0, idx=None):\n",
    "        if b == 0:\n",
    "            out = self.query(x)\n",
    "        elif b > 0:\n",
    "            out = self.query_nearby(x, b)\n",
    "        else:\n",
    "            raise ValueError('Invalid value for b (>= 0).')\n",
    "        # do not break tie, only check if idx is in the top ties\n",
    "        counts = out.sum(0)\n",
    "        indices = np.nonzero(counts.max() == counts)[0]\n",
    "        if idx is not None:\n",
    "            return out, (idx in indices)\n",
    "        else:\n",
    "            return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spread of each finite ridge along 100 random unit directions: project its\n",
    "# vertices, clip to [-sqrt(dim), sqrt(dim)] (the range reachable from the\n",
    "# [0, 1] input box), and take max - min per direction.\n",
    "w = np.random.randn(100, dim)\n",
    "w = w / np.linalg.norm(w, axis=1, keepdims=True)\n",
    "bound = np.sqrt(dim)\n",
    "\n",
    "diam = []\n",
    "for ridge in vor.ridge_vertices:\n",
    "    if -1 in ridge:\n",
    "        continue  # TODO: infinite ridges are skipped for now\n",
    "    proj = np.clip(w @ vor.vertices[ridge].T, -bound, bound)\n",
    "    diam.append(proj.max(axis=1) - proj.min(axis=1))\n",
    "diam = np.concatenate(diam, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8392"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of finite ridges, i.e. ridges whose vertex list contains no -1\n",
    "np.sum([-1 not in v for v in vor.ridge_vertices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average projected ridge length: 0.7699\n",
      "Median projected ridge length: 0.5195\n",
      "Max projected ridge length: 4.8990\n",
      "Min projected ridge length: 0.0000\n"
     ]
    }
   ],
   "source": [
    "# summary statistics of the projected ridge lengths in `diam`\n",
    "summary = [\n",
    "    ('Average projected ridge length: %.4f', diam.mean()),\n",
    "    ('Median projected ridge length: %.4f', np.median(diam)),\n",
    "    ('Max projected ridge length: %.4f', diam.max()),\n",
    "    ('Min projected ridge length: %.4f', diam.min()),\n",
    "]\n",
    "for fmt, value in summary:\n",
    "    print(fmt % value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 581,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LSH over the ridges of `vor`: 20 random unit projections, bucket width 0.1\n",
    "params = dict(dim=dim, num_proj=20, bucket_size=0.1)\n",
    "lsh = VorLSH(vor, params, seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_half_space(a, b):\n",
    "    \"\"\"\n",
    "    Get hyperplane that w.T @ x = c and w points toward a\n",
    "    \"\"\"\n",
    "    w = (b - a)\n",
    "    c = np.dot(w.T, (a + b) / 2)\n",
    "    sign = - np.sign(np.dot(w.T, b) - c)\n",
    "    w = sign * w\n",
    "    c = sign * c\n",
    "    return [w, c]\n",
    "\n",
    "\n",
    "def find_nearest_ridge(x, vor):\n",
    "    \"\"\"\n",
    "    Find the nearest point on one of the ridges of the\n",
    "    cell <x> is in.\n",
    "\n",
    "    For every other bounded cell j this solves the QP\n",
    "        min ||x_adv - x||  s.t.  x_adv is at least as close to point j as to\n",
    "                                 every other input point,\n",
    "    and keeps the solution with the smallest distance.\n",
    "\n",
    "    Returns the best x_adv (array of shape (dim,)), or 0 if no QP produced a\n",
    "    finite optimum (e.g. every other cell is unbounded).\n",
    "    \"\"\"\n",
    "    dim = len(x)\n",
    "    points = vor.points\n",
    "    # the cell x is in belongs to its nearest input point\n",
    "    near_idx = np.linalg.norm(x - points, axis=1).argmin()\n",
    "\n",
    "    best_sol = 0\n",
    "    best_dist = np.inf\n",
    "\n",
    "    for j in range(len(points)):\n",
    "        if j == near_idx:\n",
    "            continue\n",
    "\n",
    "        # TODO: unbounded cells (regions containing -1) are skipped for now\n",
    "        region = vor.regions[vor.point_region[j]]\n",
    "        if -1 in region or not region:\n",
    "            continue\n",
    "\n",
    "        # Half-spaces w_k @ z >= c_k keeping z closer to points[j] than to each other\n",
    "        # point k: the same constraints as get_half_space(points[j], points[k]),\n",
    "        # built for all k at once.\n",
    "        cur_point = points[j]\n",
    "        others = np.delete(points, j, axis=0)\n",
    "        w_const = cur_point - others\n",
    "        c_const = np.einsum('ij,ij->i', w_const, (cur_point + others) / 2)\n",
    "\n",
    "        # QP\n",
    "        x_adv = cp.Variable(dim)\n",
    "        prob = cp.Problem(cp.Minimize(cp.norm(x_adv - x)),\n",
    "                          [w_const @ x_adv >= c_const])\n",
    "        prob.solve()\n",
    "        # prob.value is None when the solver fails (inf if infeasible); skip such\n",
    "        # cells instead of crashing on the comparison below\n",
    "        if prob.value is None or not np.isfinite(prob.value):\n",
    "            continue\n",
    "        if prob.value < best_dist:\n",
    "            best_dist = prob.value\n",
    "            best_sol = x_adv.value\n",
    "\n",
    "    return best_sol\n",
    "\n",
    "\n",
    "def check_sol(x, x_adv, vertices):\n",
    "    \"\"\"\n",
    "    Check if <x_adv> found by the optimization is a\n",
    "    projection of <x> on the ridge formed by <vertices>\n",
    "    \"\"\"\n",
    "    d = x.shape[0]\n",
    "    \n",
    "    # project x onto hyperplane passing through the vertices\n",
    "    # find the hyperplane w using pseudo-inverse (same as least square solution)\n",
    "    w = np.linalg.pinv(vertices) @ np.ones(len(vertices))\n",
    "    # sanity check\n",
    "    if not np.allclose(vertices @ w, np.ones(len(vertices))):\n",
    "        import pdb; pdb.set_trace()\n",
    "\n",
    "    # check if x_adv is on hyperplane w\n",
    "    return np.allclose(w @ x_adv, 1)\n",
    "\n",
    "\n",
    "def check_candidates(x, x_adv, candidates):\n",
    "    \"\"\"\n",
    "    Index of the first ridge in <candidates> (vertex arrays returned by the LSH)\n",
    "    for which check_sol accepts the optimization result <x_adv>;\n",
    "    None when no candidate matches.\n",
    "    \"\"\"\n",
    "    matches = (k for k, verts in enumerate(candidates)\n",
    "               if check_sol(x, x_adv, verts))\n",
    "    return next(matches, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 583,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw one random query point in [0, 1]^dim and compute its exact nearest ridge point\n",
    "x = np.random.rand(dim)\n",
    "x_adv = find_nearest_ridge(x, vor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 584,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank ridges by how many projections put them in x's bucket and keep every\n",
    "# ridge tied for the highest count as a candidate.\n",
    "counts = lsh.query(x).sum(axis=0)\n",
    "top_counts = np.unique(counts)[::-1]  # distinct counts, largest first\n",
    "ind_top1 = np.flatnonzero(counts == top_counts[0])\n",
    "candidates = lsh.ridge_to_vertices(ind_top1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 585,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "False\n",
      "False\n",
      "False\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "# for each top-1 LSH candidate, does x_adv lie on the hyperplane of its vertices?\n",
    "for vertices in candidates:\n",
    "    print(check_sol(x, x_adv, vertices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 586,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average hits: 7.33\n",
      "Average max hits: 20.00\n",
      "Accuracy: 0.98\n",
      "Accuracy (top-2): 1.00\n",
      "Accuracy (top-3): 1.00\n",
      "Average number of candidates (top-1): 24.54\n",
      "Average number of candidates (top-2): 43.08\n",
      "Average number of candidates (top-3): 64.02\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the LSH on `num` random queries: compare the exact nearest ridge\n",
    "# point (from the QP) with the ridges holding the 1st/2nd/3rd highest counts.\n",
    "num = 100\n",
    "nearby_b = 1  # 0 = exact bucket only (lsh.query); b > 0 also counts +/- b buckets\n",
    "sum_out = 0\n",
    "sum_max = 0\n",
    "num_correct = 0\n",
    "num_top1, num_top2, num_top3 = 0, 0, 0\n",
    "num_correct_top2, num_correct_top3 = 0, 0\n",
    "\n",
    "seed = 2019\n",
    "np.random.seed(seed)\n",
    "\n",
    "\n",
    "def _rank_match(counts, top_counts, rank, x, x_adv):\n",
    "    \"\"\"\n",
    "    Ridges whose count equals the rank-th largest distinct count, and the index\n",
    "    of the candidate matching x_adv (None if no match).\n",
    "\n",
    "    Returns ([], None) when that rank does not exist or its count is 0: infinite\n",
    "    ridges are never stored in the LSH table, so they always have count 0, and\n",
    "    lsh.ridge_to_vertices raises ValueError for them.\n",
    "    \"\"\"\n",
    "    if rank >= len(top_counts) or top_counts[rank] == 0:\n",
    "        return [], None\n",
    "    ind = np.nonzero(top_counts[rank] == counts)[0]\n",
    "    return ind, check_candidates(x, x_adv, lsh.ridge_to_vertices(ind))\n",
    "\n",
    "\n",
    "for _ in range(num):\n",
    "    x = np.random.uniform(size=(dim, ))\n",
    "    x_adv = find_nearest_ridge(x, vor)\n",
    "\n",
    "    out = lsh.predict(x, b=nearby_b)\n",
    "    counts = out.sum(0)\n",
    "    top_counts = np.sort(np.unique(counts))[::-1]\n",
    "\n",
    "    ind_top1, match_top1 = _rank_match(counts, top_counts, 0, x, x_adv)\n",
    "    ind_top2, match_top2 = _rank_match(counts, top_counts, 1, x, x_adv)\n",
    "    ind_top3, match_top3 = _rank_match(counts, top_counts, 2, x, x_adv)\n",
    "\n",
    "    sum_out += out\n",
    "    sum_max += counts.max()\n",
    "    num_correct += (match_top1 is not None)\n",
    "    num_correct_top2 += (match_top1 is not None or\n",
    "                         match_top2 is not None)\n",
    "    num_correct_top3 += (match_top1 is not None or\n",
    "                         match_top2 is not None or\n",
    "                         match_top3 is not None)\n",
    "\n",
    "    num_top1 += len(ind_top1)\n",
    "    num_top2 += len(ind_top2)\n",
    "    num_top3 += len(ind_top3)\n",
    "\n",
    "# cumulative: the top-k candidate set contains every ridge of ranks 1..k\n",
    "num_top2 += num_top1\n",
    "num_top3 += num_top2\n",
    "\n",
    "# number of collisions averaged over all projections\n",
    "print('Average hits: %.2f' % (sum_out.sum(0).mean() / num))\n",
    "print('Average max hits: %.2f' % (sum_max / num))\n",
    "print('Accuracy: %.2f' % (num_correct / num))\n",
    "print('Accuracy (top-2): %.2f' % (num_correct_top2 / num))\n",
    "print('Accuracy (top-3): %.2f' % (num_correct_top3 / num))\n",
    "print('Average number of candidates (top-1): %.2f' % (num_top1 / num))\n",
    "print('Average number of candidates (top-2): %.2f' % (num_top2 / num))\n",
    "print('Average number of candidates (top-3): %.2f' % (num_top3 / num))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comments\n",
    "- Using more projections doesn't always help accuracy or reduce the number of top-1 candidates (the top-1 hit count can drop and become an unreliable measure).\n",
    "- Using \"nearby buckets\" increases accuracy as expected, and is usually better than increasing the bucket size (it yields fewer top-1 candidates).\n",
    "- A large number of projections requires a larger bucket size and/or a larger b.\n",
    "- A bucket size of 0.1 still feels too large when the average edge length is ~0.5.\n",
    "\n",
    "Results (dim = 4)\n",
    "- `(num_proj=5, bucket_size=0.1, b=0)\n",
    "Average hits: 1.27\n",
    "Average max hits: 5.00\n",
    "Accuracy: 0.84\n",
    "Accuracy (top-2): 0.99\n",
    "Accuracy (top-3): 1.00\n",
    "Average number of candidates (top-1): 23.62\n",
    "Average number of candidates (top-2): 94.11\n",
    "Average number of candidates (top-3): 221.08`\n",
    "- `(num_proj=5, bucket_size=0.1, b=1)\n",
    "Average hits: 1.81\n",
    "Average max hits: 5.00\n",
    "Accuracy: 0.99\n",
    "Accuracy (top-2): 1.00\n",
    "Accuracy (top-3): 1.00\n",
    "Average number of candidates (top-1): 73.27\n",
    "Average number of candidates (top-2): 209.98\n",
    "Average number of candidates (top-3): 397.70`\n",
    "- `(num_proj=10, bucket_size=0.1, b=1)\n",
    "Average hits: 3.69\n",
    "Average max hits: 10.00\n",
    "Accuracy: 0.99\n",
    "Accuracy (top-2): 1.00\n",
    "Accuracy (top-3): 1.00\n",
    "Average number of candidates (top-1): 41.38\n",
    "Average number of candidates (top-2): 85.18\n",
    "Average number of candidates (top-3): 148.07`\n",
    "\n",
    "Results (dim = 6)\n",
    "- `(num_proj=10, bucket_size=0.1, b=0)\n",
    "Average hits: 2.45\n",
    "Average max hits: 10.00\n",
    "Accuracy: 0.96\n",
    "Accuracy (top-2): 0.98\n",
    "Accuracy (top-3): 1.00\n",
    "Average number of candidates (top-1): 60.18\n",
    "Average number of candidates (top-2): 156.35\n",
    "Average number of candidates (top-3): 274.72`\n",
    "- `(num_proj=10, bucket_size=0.1, b=1)\n",
    "Average hits: 3.03\n",
    "Average max hits: 10.00\n",
    "Accuracy: 1.00\n",
    "Accuracy (top-2): 1.00\n",
    "Accuracy (top-3): 1.00\n",
    "Average number of candidates (top-1): 146.22\n",
    "Average number of candidates (top-2): 308.81\n",
    "Average number of candidates (top-3): 473.29`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Voronoi diagram is not 2-D",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-57-a90c00dcd72e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# only work with dim = 2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvoronoi_plot_2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshow_vertices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshow_points\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'blue'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx_adv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mx_adv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'green'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda/envs/py36/lib/python3.6/site-packages/scipy/spatial/_plotutils.py\u001b[0m in \u001b[0;36mvoronoi_plot_2d\u001b[0;34m(vor, ax, **kw)\u001b[0m\n",
      "\u001b[0;32m~/miniconda/envs/py36/lib/python3.6/site-packages/scipy/spatial/_plotutils.py\u001b[0m in \u001b[0;36m_held_figure\u001b[0;34m(func, obj, ax, **kw)\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0max\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgca\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m     \u001b[0;31m# As of matplotlib 2.0, the \"hold\" mechanism is deprecated.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda/envs/py36/lib/python3.6/site-packages/scipy/spatial/_plotutils.py\u001b[0m in \u001b[0;36mvoronoi_plot_2d\u001b[0;34m(vor, ax, **kw)\u001b[0m\n\u001b[1;32m    213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    214\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mvor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpoints\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Voronoi diagram is not 2-D\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    217\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mkw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'show_points'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Voronoi diagram is not 2-D"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAANgElEQVR4nO3ccYjfd33H8efLxE6mtY7lBEmi7Vi6Gsqg7ug6hFnRjbR/JP8USaC4SmnArQ5mETocKvWvKUMQsmm2iVPQWv1DD4nkD1fpECO50lmalMAtOnNE6Fm7/lO0Znvvj99P77hcct/e/e4u3vv5gMDv+/t9fr9758PdM798f/f7paqQJG1/r9rqASRJm8PgS1ITBl+SmjD4ktSEwZekJgy+JDWxavCTfC7Jc0meucLtSfLpJHNJnk7ytsmPKUlaryHP8D8PHLjK7XcB+8Z/jgL/tP6xJEmTtmrwq+oJ4GdXWXII+EKNnALekORNkxpQkjQZOyfwGLuBC0uO58fX/WT5wiRHGf0vgNe+9rV/dMstt0zgy0tSH08++eRPq2pqLfedRPCzwnUrfl5DVR0HjgNMT0/X7OzsBL68JPWR5L/Xet9J/JbOPLB3yfEe4OIEHleSNEGTCP4M8N7xb+vcAbxYVZedzpEkba1VT+kk+TJwJ7AryTzwUeDVAFX1GeAEcDcwB7wEvG+jhpUkrd2qwa+qI6vcXsBfTWwiSdKG8J22ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNTEo+EkOJDmXZC7Jwyvc/uYkjyd5KsnTSe6e/KiSpPVYNfhJdgDHgLuA/cCRJPuXLfs74LGqug04DPzjpAeVJK3PkGf4twNzVXW+ql4GHgUOLVtTwOvHl28ALk5uREnSJAwJ/m7gwpLj+fF1S30MuDfJPHAC+MBKD5TkaJLZJLMLCwtrGFeStFZDgp8Vrqtlx0eAz1fVHuBu4ItJLnvsqjpeVdNVNT01NfXKp5UkrdmQ4M8De5cc7+HyUzb3A48BVNX3gNcAuyYxoCRpMoYE/zSwL8lNSa5j9KLszLI1PwbeBZDkrYyC7zkbSbqGrBr8qroEPAicBJ5l9Ns4Z5I8kuTgeNlDwANJfgB8Gbivqpaf9pEkbaGdQxZV1QlGL8Yuve4jSy6fBd4+2dEkSZPkO20lqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0MCn6SA0nOJZlL8vAV1rwnydkkZ5J8abJjSpLWa+dqC5LsAI4BfwbMA6eTzFTV2SVr9gF/C7y9ql5I8saNGliStDZDnuHfDsxV1fmqehl4FDi0bM0DwLGqegGgqp6b7JiSpPUaEvzdwIUlx/Pj65a6Gbg5yXeTnEpyYKUHSnI0yWyS2YWFhbVNLElakyHBzwrX1bLjncA+4E7gCPAvSd5w2Z2qjlfVdFVNT01NvdJZJUnrMCT488DeJcd7gIsrrPlGVf2yqn4InGP0D4Ak6RoxJPingX1JbkpyHXAYmFm25uvAOwGS7GJ0iuf8JAeVJK3PqsGvqkvAg8BJ4Fngsao6k+SRJAfHy04Czyc5CzwOfKiqnt+ooSVJr1yqlp+O3xzT09M1Ozu7JV9bkn5TJXmyqqbXcl/faStJTRh8SWrC4EtS
EwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITg4Kf5ECSc0nmkjx8lXX3JKkk05MbUZI0CasGP8kO4BhwF7AfOJJk/wrrrgf+Gvj+pIeUJK3fkGf4twNzVXW+ql4GHgUOrbDu48AngJ9PcD5J0oQMCf5u4MKS4/nxdb+W5DZgb1V982oPlORoktkkswsLC694WEnS2g0Jfla4rn59Y/Iq4FPAQ6s9UFUdr6rpqpqempoaPqUkad2GBH8e2LvkeA9wccnx9cCtwHeS/Ai4A5jxhVtJurYMCf5pYF+Sm5JcBxwGZn51Y1W9WFW7qurGqroROAUcrKrZDZlYkrQmqwa/qi4BDwIngWeBx6rqTJJHkhzc6AElSZOxc8iiqjoBnFh23UeusPbO9Y8lSZo032krSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWpiUPCTHEhyLslckodXuP2DSc4meTrJt5O8ZfKjSpLWY9XgJ9kBHAPuAvYDR5LsX7bsKWC6qv4Q+BrwiUkPKklanyHP8G8H5qrqfFW9DDwKHFq6oKoer6qXxoengD2THVOStF5Dgr8buLDkeH583ZXcD3xrpRuSHE0ym2R2YWFh+JSSpHUbEvyscF2tuDC5F5gGPrnS7VV1vKqmq2p6ampq+JSSpHXbOWDNPLB3yfEe4OLyRUneDXwYeEdV/WIy40mSJmXIM/zTwL4kNyW5DjgMzCxdkOQ24LPAwap6bvJjSpLWa9XgV9Ul4EHgJPAs8FhVnUnySJKD42WfBF4HfDXJfyaZucLDSZK2yJBTOlTVCeDEsus+suTyuyc8lyRpwnynrSQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0MCn6SA0nOJZlL8vAKt/9Wkq+Mb/9+khsnPagkaX1WDX6SHcAx4C5gP3Akyf5ly+4HXqiq3wc+Bfz9pAeVJK3PkGf4twNzVXW+ql4GHgUOLVtzCPi38eWvAe9KksmNKUlar50D1uwGLiw5ngf++EprqupSkheB3wV+unRRkqPA0fHhL5I8s5aht6FdLNurxtyLRe7FIvdi0R+s9Y5Dgr/SM/Vawxqq6jhwHCDJbFVND/j62557sci9WOReLHIvFiWZXet9h5zSmQf2LjneA1y80pokO4EbgJ+tdShJ0uQNCf5pYF+Sm5JcBxwGZpatmQH+Ynz5HuDfq+qyZ/iSpK2z6imd8Tn5B4GTwA7gc1V1JskjwGxVzQD/CnwxyRyjZ/aHB3zt4+uYe7txLxa5F4vci0XuxaI170V8Ii5JPfhOW0lqwuBLUhMbHnw/lmHRgL34YJKzSZ5O8u0kb9mKOTfDanuxZN09SSrJtv2VvCF7keQ94++NM0m+tNkzbpYBPyNvTvJ4kqfGPyd3b8WcGy3J55I8d6X3KmXk0+N9ejrJ2wY9cFVt2B9GL/L+F/B7wHXAD4D9y9b8JfCZ8eXDwFc2cqat+jNwL94J/Pb48vs778V43fXAE8ApYHqr597C74t9wFPA74yP37jVc2/hXhwH
3j++vB/40VbPvUF78afA24BnrnD73cC3GL0H6g7g+0Med6Of4fuxDItW3YuqeryqXhofnmL0noftaMj3BcDHgU8AP9/M4TbZkL14ADhWVS8AVNVzmzzjZhmyFwW8fnz5Bi5/T9C2UFVPcPX3Mh0CvlAjp4A3JHnTao+70cFf6WMZdl9pTVVdAn71sQzbzZC9WOp+Rv+Cb0er7kWS24C9VfXNzRxsCwz5vrgZuDnJd5OcSnJg06bbXEP24mPAvUnmgRPABzZntGvOK+0JMOyjFdZjYh/LsA0M/nsmuReYBt6xoRNtnavuRZJXMfrU1fs2a6AtNOT7Yiej0zp3Mvpf338kubWq/meDZ9tsQ/biCPD5qvqHJH/C6P0/t1bV/238eNeUNXVzo5/h+7EMi4bsBUneDXwYOFhVv9ik2TbbantxPXAr8J0kP2J0jnJmm75wO/Rn5BtV9cuq+iFwjtE/ANvNkL24H3gMoKq+B7yG0QerdTOoJ8ttdPD9WIZFq+7F+DTGZxnFfruep4VV9qKqXqyqXVV1Y1XdyOj1jINVteYPjbqGDfkZ+TqjF/RJsovRKZ7zmzrl5hiyFz8G3gWQ5K2Mgr+wqVNeG2aA945/W+cO4MWq+slqd9rQUzq1cR/L8Btn4F58Engd8NXx69Y/rqqDWzb0Bhm4Fy0M3IuTwJ8nOQv8L/Chqnp+66beGAP34iHgn5P8DaNTGPdtxyeISb7M6BTervHrFR8FXg1QVZ9h9PrF3cAc8BLwvkGPuw33SpK0At9pK0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDXx/4aZaro1YsjCAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# voronoi_plot_2d only supports 2-D diagrams, so skip the plot instead of raising.\n",
    "# NOTE(review): `candidates` comes from the single-query cell, while x / x_adv may\n",
    "# have been overwritten since by the evaluation loop -- re-run those cells first.\n",
    "if vor.points.shape[1] == 2:\n",
    "    fig = voronoi_plot_2d(vor, show_vertices=False, show_points=True)\n",
    "    plt.scatter([x[0]], [x[1]], c='blue')  # query point\n",
    "    plt.scatter([x_adv[0]], [x_adv[1]], c='green')  # exact nearest ridge point\n",
    "    for c in candidates:\n",
    "        plt.scatter(c[:, 0], c[:, 1], c='red')  # vertices of top-1 LSH ridges\n",
    "    plt.show()\n",
    "else:\n",
    "    print('Skipping plot: voronoi_plot_2d needs a 2-D diagram, got %d-D.' % vor.points.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.14434432983398438 32.33 96\n"
     ]
    }
   ],
   "source": [
    "num_points = 100\n",
    "dim = 6\n",
    "\n",
    "seed = 123\n",
    "np.random.seed(seed)\n",
    "\n",
    "# sample sites and test queries uniformly at random inside the [0, 1]^dim box\n",
    "points = np.random.uniform(size=(num_points, dim))\n",
    "test_points = np.random.uniform(size=(num_points, dim))\n",
    "\n",
    "# construct Voronoi diagram\n",
    "vor = Voronoi(points)\n",
    "\n",
    "params = {'dim': dim,\n",
    "          'num_proj': 50,\n",
    "          'bucket_size': 0.1}\n",
    "lsh = VorLSH(vor, params, seed=seed)\n",
    "\n",
    "qt = 0  # total LSH query time in seconds; the exact QP is not timed\n",
    "num_top1 = 0\n",
    "num_correct = 0\n",
    "\n",
    "for x_test in test_points:\n",
    "    x_adv = find_nearest_ridge(x_test, vor)\n",
    "    # perf_counter: monotonic high-resolution clock for these ms-scale timings\n",
    "    start = time.perf_counter()\n",
    "    out = lsh.query_nearby(x_test, 1)\n",
    "    qt += time.perf_counter() - start\n",
    "    counts = out.sum(0)\n",
    "    # every ridge tied for the most collisions is a top-1 candidate\n",
    "    ind_top1 = np.nonzero(counts.max() == counts)[0]\n",
    "    vert_top1 = lsh.ridge_to_vertices(ind_top1)\n",
    "    match_top1 = check_candidates(x_test, x_adv, vert_top1)\n",
    "\n",
    "    num_correct += (match_top1 is not None)\n",
    "    num_top1 += len(ind_top1)\n",
    "\n",
    "# total query time, average number of top-1 candidates, number of correct queries\n",
    "print(qt, num_top1 / num_points, num_correct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "# `with` closes the file handle, which the bare open() call left dangling.\n",
    "with open('exact_dim2_ns100_seed123.p', 'rb') as f:\n",
    "    x = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small 2-D problem: 100 sites and 100 test queries drawn from [0, 1]^dim\n",
    "num_points, dim, seed = 100, 2, 123\n",
    "np.random.seed(seed)\n",
    "points = np.random.uniform(size=(num_points, dim))\n",
    "test_points = np.random.uniform(size=(num_points, dim))\n",
    "\n",
    "vor = Voronoi(points)\n",
    "# LSH with 50 random projections and bucket width 0.1\n",
    "params = dict(dim=dim, num_proj=50, bucket_size=0.1)\n",
    "lsh = VorLSH(vor, params, seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exact nearest ridge point for test query 1\n",
    "x_adv = find_nearest_ridge(test_points[1], vor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.65336487, 0.99608633])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the query point passed to find_nearest_ridge\n",
    "test_points[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0.54263593, 0.06677444],\n",
       "        [0.65336487, 0.99608633],\n",
       "        [0.76939734, 0.57377411],\n",
       "        [0.10263526, 0.69983407],\n",
       "        [0.66116787, 0.04909713],\n",
       "        [0.79229934, 0.51871658],\n",
       "        [0.42586769, 0.78818717],\n",
       "        [0.41156922, 0.48102628],\n",
       "        [0.18162884, 0.3213189 ],\n",
       "        [0.845533  , 0.18690375],\n",
       "        [0.41729108, 0.98903452],\n",
       "        [0.23660025, 0.91683267],\n",
       "        [0.91839747, 0.09129634],\n",
       "        [0.46365271, 0.50221629],\n",
       "        [0.31366895, 0.04733954],\n",
       "        [0.24168564, 0.09552964],\n",
       "        [0.23824991, 0.80779108],\n",
       "        [0.89497826, 0.04322289],\n",
       "        [0.30194679, 0.98058218],\n",
       "        [0.53950482, 0.62630936],\n",
       "        [0.00554541, 0.48490944],\n",
       "        [0.98833874, 0.37519884],\n",
       "        [0.09703816, 0.46190876],\n",
       "        [0.96300447, 0.34183062],\n",
       "        [0.79892276, 0.79884633],\n",
       "        [0.20824829, 0.4433677 ],\n",
       "        [0.71560128, 0.41051979],\n",
       "        [0.19100692, 0.96749431],\n",
       "        [0.65075035, 0.86545983],\n",
       "        [0.02524236, 0.26690581],\n",
       "        [0.50207105, 0.06744865],\n",
       "        [0.99303326, 0.2364624 ],\n",
       "        [0.37429218, 0.21401192],\n",
       "        [0.10544585, 0.23247981],\n",
       "        [0.30060999, 0.63444212],\n",
       "        [0.2812348 , 0.36227676],\n",
       "        [0.00594284, 0.36571913],\n",
       "        [0.53388598, 0.16201584],\n",
       "        [0.59743312, 0.29315246],\n",
       "        [0.63205049, 0.0261966 ],\n",
       "        [0.88759345, 0.01611863],\n",
       "        [0.12695866, 0.77716183],\n",
       "        [0.04589522, 0.7109987 ],\n",
       "        [0.97104614, 0.87168293],\n",
       "        [0.71016168, 0.95850983],\n",
       "        [0.42981334, 0.87287892],\n",
       "        [0.35595755, 0.92976363],\n",
       "        [0.14877766, 0.94002901],\n",
       "        [0.8327162 , 0.84605484],\n",
       "        [0.12392315, 0.59648675],\n",
       "        [0.01639248, 0.72118437],\n",
       "        [0.00773749, 0.08482227],\n",
       "        [0.22549843, 0.87512453],\n",
       "        [0.36357632, 0.53995994],\n",
       "        [0.56810321, 0.22546336],\n",
       "        [0.57214678, 0.66095181],\n",
       "        [0.2982454 , 0.41862686],\n",
       "        [0.45308883, 0.9323505 ],\n",
       "        [0.58749372, 0.94825236],\n",
       "        [0.5560349 , 0.50056139],\n",
       "        [0.00353221, 0.48088905],\n",
       "        [0.927455  , 0.19836569],\n",
       "        [0.05209113, 0.40677889],\n",
       "        [0.37239659, 0.85715302],\n",
       "        [0.02661065, 0.92014974],\n",
       "        [0.68090301, 0.90422601],\n",
       "        [0.60752901, 0.8119533 ],\n",
       "        [0.33554389, 0.34956625],\n",
       "        [0.38987423, 0.75479708],\n",
       "        [0.36929117, 0.24221981],\n",
       "        [0.9376684 , 0.90801084],\n",
       "        [0.3487973 , 0.63463819],\n",
       "        [0.27384222, 0.20611513],\n",
       "        [0.33633953, 0.32709988],\n",
       "        [0.8822761 , 0.82230381],\n",
       "        [0.70962326, 0.95934531],\n",
       "        [0.42254337, 0.24503303],\n",
       "        [0.11739844, 0.30105336],\n",
       "        [0.14526373, 0.0921861 ],\n",
       "        [0.6029322 , 0.36418745],\n",
       "        [0.56457036, 0.19133571],\n",
       "        [0.67690583, 0.21550556],\n",
       "        [0.2780228 , 0.74175911],\n",
       "        [0.5597379 , 0.33483641],\n",
       "        [0.54298831, 0.69398455],\n",
       "        [0.91213068, 0.58070885],\n",
       "        [0.23268638, 0.74669763],\n",
       "        [0.77776904, 0.20040136],\n",
       "        [0.82057548, 0.46493699],\n",
       "        [0.77976666, 0.23747822],\n",
       "        [0.33258027, 0.95369712],\n",
       "        [0.65781507, 0.77287782],\n",
       "        [0.68837434, 0.20430412],\n",
       "        [0.47068875, 0.80896387],\n",
       "        [0.67503513, 0.00602789],\n",
       "        [0.08740774, 0.34679472],\n",
       "        [0.94436554, 0.49119048],\n",
       "        [0.27017627, 0.36042372],\n",
       "        [0.21065262, 0.42120006],\n",
       "        [0.21803544, 0.84575251]]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): this displays the whole loaded array; x.shape or a small slice keeps the output short\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy example: 5 random points in 3-D (unseeded, so the diagram differs between runs)\n",
    "points = np.random.uniform(low=0.0, high=1.0, size=(5, 3))\n",
    "vor = Voronoi(points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, 1, -1], [0, 1, -1], [-1, 0, 1]]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Voronoi-vertex indices of each ridge; -1 marks a vertex outside the diagram (at infinity)\n",
    "vor.ridge_vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3x3 grid in the x=0 plane, plus copies shifted to x=+1 and x=-1 -> 27-point integer lattice\n",
    "points0 = np.array([[0, y, z] for y in (0, 1, -1) for z in (0, 1, -1)])\n",
    "points1 = points0 + np.array([1, 0, 0])\n",
    "points2 = points0 - np.array([1, 0, 0])\n",
    "points = np.concatenate([points0, points1, points2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Voronoi diagram of the 27-point lattice\n",
    "vor = Voronoi(points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, 1, -1],\n",
       " [0, -1, 2],\n",
       " [0, 1, 3, 2],\n",
       " [2, 3, -1],\n",
       " [1, 3, -1],\n",
       " [1, 3, -1],\n",
       " [0, 2, -1],\n",
       " [0, 1, -1],\n",
       " [0, 4, -1],\n",
       " [0, 4, 5, 1],\n",
       " [1, 5, -1],\n",
       " [-1, 5, 4],\n",
       " [1, 5, -1],\n",
       " [0, 4, -1],\n",
       " [-1, 6, 7],\n",
       " [6, 7, -1],\n",
       " [2, 3, -1],\n",
       " [-1, 2, 7],\n",
       " [3, 2, 7, 6],\n",
       " [6, 3, -1],\n",
       " [4, -1, 5],\n",
       " [4, 7, -1],\n",
       " [4, 7, 6, 5],\n",
       " [5, 6, -1],\n",
       " [6, 3, -1],\n",
       " [2, 7, -1],\n",
       " [6, -1, 5],\n",
       " [0, 4, 7, 2],\n",
       " [1, 3, 6, 5],\n",
       " [-1, 4, 7]]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ridge_vertices excludes infinite ridge with only one vertex\n",
    "# ridge_points[i] -> indices of the two input points that ridge i separates\n",
    "vor.ridge_vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.,  1.,  0.],\n",
       "       [-1.,  1.,  0.]])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the two input points separated by ridge 1\n",
    "vor.points[vor.ridge_points[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([-0.5,  0.5,  0.5]), array([-0.5,  0.5, -0.5]))"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# finite vertices of ridge 1 (its ridge_vertices entry is [0, -1, 2])\n",
    "vor.vertices[0], vor.vertices[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.5  0.5  0.5]\n",
      "[-0.5 -0.5  0.5]\n",
      "[ 0.5 -0.5  0.5]\n",
      "[0.5 0.5 0.5]\n"
     ]
    }
   ],
   "source": [
    "# coordinates of the ridge with vertex indices [0, 4, 5, 1], printed in that order\n",
    "for v in vor.vertices[[0, 4, 5, 1]]:\n",
    "    print(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[-1, 0, 1, 2, 3],\n",
       " [-1, 1, 3],\n",
       " [],\n",
       " [-1, 0, 2],\n",
       " [-1, 0, 1, 4, 5],\n",
       " [-1, 1],\n",
       " [-1, 1, 5],\n",
       " [-1, 0, 1],\n",
       " [-1, 0, 4],\n",
       " [-1, 0],\n",
       " [-1, 6, 7],\n",
       " [-1, 7],\n",
       " [-1, 6],\n",
       " [-1, 2, 3, 6, 7],\n",
       " [-1, 4, 5, 6, 7],\n",
       " [-1, 3],\n",
       " [-1, 3, 6],\n",
       " [-1, 2],\n",
       " [-1, 2, 3],\n",
       " [-1, 2, 7],\n",
       " [-1, 5, 6],\n",
       " [0, 1, 2, 3, 4, 5, 6, 7],\n",
       " [-1, 1, 3, 5, 6],\n",
       " [-1, 4, 7],\n",
       " [-1, 0, 2, 4, 7],\n",
       " [-1, 4, 5],\n",
       " [-1, 5],\n",
       " [-1, 4]]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Voronoi regions as lists of vertex indices; -1 entries mark vertices outside the diagram\n",
    "vor.regions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[-1, 0, 1, 2, 3],\n",
       " [-1, 1, 3],\n",
       " [],\n",
       " [-1, 0, 2],\n",
       " [-1, 0, 1, 4, 5],\n",
       " [-1, 1],\n",
       " [-1, 1, 5],\n",
       " [-1, 0, 1],\n",
       " [-1, 0, 4],\n",
       " [-1, 0],\n",
       " [-1, 6, 7],\n",
       " [-1, 7],\n",
       " [-1, 6],\n",
       " [-1, 2, 3, 6, 7],\n",
       " [-1, 4, 5, 6, 7],\n",
       " [-1, 3],\n",
       " [-1, 3, 6],\n",
       " [-1, 2],\n",
       " [-1, 2, 3],\n",
       " [-1, 2, 7],\n",
       " [-1, 5, 6],\n",
       " [0, 1, 2, 3, 4, 5, 6, 7],\n",
       " [-1, 1, 3, 5, 6],\n",
       " [-1, 4, 7],\n",
       " [-1, 0, 2, 4, 7],\n",
       " [-1, 4, 5],\n",
       " [-1, 5],\n",
       " [-1, 4]]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): repeats the earlier vor.regions cell with identical output; candidate for removal\n",
    "vor.regions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# Halfspace Intersection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial import HalfspaceIntersection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HalfspaceIntersection takes rows [A; b] encoding A @ x + b <= 0.\n",
    "# With b = -1 every constraint reads a_i . x <= 1, so the origin is strictly interior.\n",
    "dim, num_planes = 10, 1000\n",
    "\n",
    "normals = np.random.randn(num_planes, dim)\n",
    "offsets = np.full((num_planes, 1), -1.0)\n",
    "halfspaces = np.hstack([normals, offsets])\n",
    "\n",
    "feasible_point = np.zeros(dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "755.7811999320984\n"
     ]
    }
   ],
   "source": [
    "# time.perf_counter is a monotonic, high-resolution clock meant for measuring elapsed intervals\n",
    "start = time.perf_counter()\n",
    "hs = HalfspaceIntersection(halfspaces, feasible_point)\n",
    "end = time.perf_counter()\n",
    "print(end - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
