{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=1\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# All imports in one place: stdlib -> third-party -> local.\n",
    "# CUDA_VISIBLE_DEVICES is set above, before torch is imported, so torch only sees that GPU.\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from utils import write_fvecs, write_ivecs, read_fvecs, read_ivecs, read_dvecs\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sphere_d6\n"
     ]
    }
   ],
   "source": [
    "# Dataset configuration\n",
    "d = 6        # dimension tag; used here to build the dataset directory name\n",
    "n = 1010000  # NOTE(review): presumably 1M base + 10K query points; confirm whether still needed\n",
    "k = 100\n",
    "\n",
    "set_path = 'sphere_d{}'.format(d)\n",
    "print(set_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = os.path.join('gbnns_theory/data', set_path)\n",
    "base = read_dvecs(os.path.join(data_dir, 'base.1M.dvecs'))\n",
    "query = read_dvecs(os.path.join(data_dir, 'query.10K.dvecs'))\n",
    "\n",
    "# The groundtruth is recomputed and written further down, so on a fresh dataset the\n",
    "# file may not exist yet: load it only when present instead of crashing Run All.\n",
    "gt_path = os.path.join(data_dir, 'gt.10K.ivecs')\n",
    "gt = read_ivecs(gt_path) if os.path.exists(gt_path) else None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** an earlier revision loaded a precomputed graph here as `edges` from `gbnns_theory/models/<set_path>/hyperbolic_d4_X100_knn.ivecs`. That load is disabled; the next cell builds a brute-force kNN graph directly from `base`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b15b2ba04fa7424ab48f7e46d1660bce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Brute-force kNN graph over the base set, ranked by inner product (largest = most similar).\n",
    "# NOTE(review): this matches Euclidean ranking only if base vectors are unit-norm; confirm for this dataset.\n",
    "batch_size = 500\n",
    "n_neighbors = 300\n",
    "\n",
    "knn = torch.zeros((len(base), n_neighbors), dtype=torch.int32)\n",
    "t_base = torch.as_tensor(base).cuda()\n",
    "t_base_t = t_base.t()\n",
    "for i in tqdm(range(0, len(base), batch_size)):\n",
    "    dists = t_base[i: i + batch_size] @ t_base_t\n",
    "    # Exclude each point from its own neighbour list explicitly instead of assuming it is\n",
    "    # always the top-1 hit and dropping column 0 (ties, duplicates or non-unit norms break that).\n",
    "    rows = torch.arange(dists.shape[0], device=dists.device)\n",
    "    dists[rows, rows + i] = float('-inf')\n",
    "    knn[i: i + batch_size] = dists.topk(n_neighbors, dim=1, largest=True)[1].cpu()\n",
    "\n",
    "del dists  # release the last (batch, len(base)) similarity block on the GPU\n",
    "knn = knn.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0b51b5847c442499c2e0b1b8a67257d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Exact top-k groundtruth for every query by inner product against the whole base set.\n",
    "batch_size = 500\n",
    "gt = torch.zeros((len(query), k), dtype=torch.int32)\n",
    "\n",
    "t_query = torch.as_tensor(query)\n",
    "# .cuda() already copies host memory to the device, so no extra .clone() is needed.\n",
    "t_base_t = torch.as_tensor(base).t().cuda()\n",
    "\n",
    "for i in tqdm(range(0, len(query), batch_size)):\n",
    "    dists = t_query[i: i + batch_size].cuda() @ t_base_t\n",
    "    gt[i: i + batch_size] = dists.topk(k, dim=1, largest=True)[1].cpu()\n",
    "\n",
    "gt = gt.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the freshly computed groundtruth next to the base/query files.\n",
    "gt_file = os.path.join('gbnns_theory/data', set_path, 'gt.10K.ivecs')\n",
    "write_ivecs(gt_file, gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the kNN graph under models/<set_path>/, creating the directory on first run.\n",
    "model_dir = os.path.join('gbnns_theory/models', set_path)\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "write_ivecs(os.path.join(model_dir, 'knn.ivecs'), knn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "synthetic_dir = 'gbnns_theory/data/synthetic'\n",
    "cpp_base = read_fvecs(os.path.join(synthetic_dir, 'synthetic_database_n_10_6_d_5.fvecs'))\n",
    "cpp_query = read_fvecs(os.path.join(synthetic_dir, 'synthetic_query_n_10_4_d_5.fvecs'))\n",
    "cpp_gt = read_ivecs(os.path.join(synthetic_dir, 'synthetic_groundtruth_n_10_4_d_5.ivecs'))\n",
    "# FIXME(review): this reads the *groundtruth* file a second time, so `edges` has the same\n",
    "# contents as `cpp_gt` rather than being a neighbour graph. Point it at the intended graph file.\n",
    "edges = read_ivecs(os.path.join(synthetic_dir, 'synthetic_groundtruth_n_10_4_d_5.ivecs'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense (n_query, n_base) inner-product matrix.\n",
    "# NOTE(review): judging by the file names (1e4 queries, 1e6 base) this is ~40 GB if the\n",
    "# .fvecs data is float32; consider computing the argmax in query batches instead.\n",
    "dots = cpp_query @ cpp_base.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([241321, 331270, 243370, ..., 499704, 699255, 608603])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Brute-force top-1 per query, checked against the stored groundtruth rather than by\n",
    "# eyeballing the first/last few ids: prints the fraction of queries that agree (1.0 = all).\n",
    "top1 = dots.argmax(1)\n",
    "print('top-1 agreement with cpp_gt:', (top1 == cpp_gt[:, 0]).mean())\n",
    "top1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[241321],\n",
       "       [331270],\n",
       "       [243370],\n",
       "       ...,\n",
       "       [499704],\n",
       "       [699255],\n",
       "       [608603]], dtype=int32)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stored synthetic groundtruth ids (a single column per row, as the output shows).\n",
    "cpp_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([808004, 163266,   2178, 235115, 497946, 281246, 673540, 369055,\n",
       "       645141, 103458, 700435, 760342, 568457, 554952, 688674, 263605,\n",
       "       582387, 201398, 988447, 171143, 671075, 272040, 497013, 587664,\n",
       "        33044, 900534, 317352,  55123, 190976, 559944, 801642, 647259,\n",
       "       578615, 936133,  68184, 410139, 803571, 936290, 723952,  29175,\n",
       "       363503, 568983, 534858, 510112, 370327, 318955, 239975, 767851,\n",
       "       694063, 172291, 408236, 809848, 874292, 863527, 733568, 912418,\n",
       "       265000, 444846, 980332, 687217, 656623, 953541, 165436, 225608,\n",
       "       102889, 815682, 603273, 700773, 983627, 667117, 816796, 688689,\n",
       "       441486,  25538, 275313, 202864, 617844, 378004, 950287, 480209,\n",
       "       597789, 226628, 712190, 393227, 563759, 558459, 816895, 607852,\n",
       "       621942, 788256, 659582, 426229, 177875, 438870, 995158, 468675,\n",
       "       742055, 200147, 788257, 782984, 247077, 113214, 601611, 297580,\n",
       "       138098, 834447, 403731, 289127, 641132, 213120, 261138, 192064,\n",
       "       195532, 809292, 377161, 194867, 291545,  86013, 711740, 310618,\n",
       "        71682,   3243, 912996, 660758, 867230, 164497, 662908, 982131,\n",
       "        91717,  97966, 819027, 964330, 363057, 908228, 760881, 151932,\n",
       "       241329, 375262, 638249, 126900, 457188,  66288, 745732, 143515,\n",
       "       368384, 929126, 495439, 325964, 113460, 335994, 770521, 492623,\n",
       "       119459, 736064, 409833, 662442,  17225, 332401, 229879, 478514,\n",
       "       701577, 623371, 868663, 444002, 683311,   9664,   6318, 715441,\n",
       "       749064, 891864, 977861, 560874, 910271, 475889, 209486, 540349,\n",
       "       774608, 409629, 548389, 480716, 259708, 769909,  61510, 938155,\n",
       "       407981, 252497, 714505, 116801, 400206, 346754, 277789, 313419,\n",
       "       575132, 384121, 924167, 322034, 745773, 844331,  73898, 325092,\n",
       "        94929, 812155, 470031, 903545, 227294, 670238, 899158, 223385,\n",
       "       775129, 437355, 291997, 548946, 363460, 600658, 896732, 317406,\n",
       "       331744, 363646, 211267, 603310, 766680, 880397, 777072, 648053,\n",
       "       551911, 177048, 742496, 613859, 816760, 856202,  87476, 480029,\n",
       "       964530, 803658,  22319, 248970, 410675, 398569, 576885, 947447,\n",
       "       668835, 159900, 518896, 602366, 802994,  56769, 899362, 755508,\n",
       "        60152, 622942, 924649, 701892,  12644, 984324, 925737, 345800,\n",
       "       342238, 894336, 944994, 202743, 470099, 908256,  29691, 763109,\n",
       "       625212,  46874, 266614, 258391, 424856, 392261, 789771, 833537,\n",
       "       928134, 528446, 731867, 745147, 388027, 911990, 809479, 104233,\n",
       "       894816, 309787, 363397, 213357, 331192, 107173, 906454, 349284,\n",
       "       239659, 159471, 187942, 178611, 739655, 616867, 416669, 806295,\n",
       "        70191, 933975, 570916, 847032], dtype=int32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Neighbour ids of base point 0, reversed so the least similar of its neighbours comes first.\n",
    "knn[0, ::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([808004, 163266,   2178, 235115, 497946, 281246, 673540, 369055,\n",
       "       645141, 103458, 700435, 760342, 568457, 554952, 688674, 263605,\n",
       "       582387, 201398, 988447, 171143, 671075, 272040, 497013, 587664,\n",
       "        33044, 900534, 317352,  55123, 190976, 559944, 801642, 647259,\n",
       "       578615, 936133,  68184, 410139, 803571, 936290, 723952,  29175,\n",
       "       363503, 568983, 534858, 510112, 370327, 318955, 239975, 767851,\n",
       "       694063, 172291, 408236, 809848, 874292, 863527, 733568, 912418,\n",
       "       265000, 444846, 980332, 687217, 656623, 953541, 165436, 225608,\n",
       "       102889, 815682, 603273, 700773, 983627, 667117, 816796, 688689,\n",
       "       441486,  25538, 275313, 202864, 617844, 378004, 950287, 480209,\n",
       "       597789, 226628, 712190, 393227, 563759, 558459, 816895, 607852,\n",
       "       621942, 788256, 659582, 426229, 177875, 438870, 995158, 468675,\n",
       "       742055, 200147, 788257, 782984, 247077, 113214, 601611, 297580,\n",
       "       138098, 834447, 403731, 289127, 641132, 213120, 261138, 192064,\n",
       "       195532, 809292, 377161, 194867, 291545,  86013, 711740, 310618,\n",
       "        71682,   3243, 912996, 660758, 867230, 164497, 662908, 982131,\n",
       "        91717,  97966, 819027, 964330, 363057, 908228, 760881, 151932,\n",
       "       241329, 375262, 638249, 126900, 457188,  66288, 745732, 143515,\n",
       "       368384, 929126, 495439, 325964, 113460, 335994, 770521, 492623,\n",
       "       119459, 736064, 409833, 662442,  17225, 332401, 229879, 478514,\n",
       "       701577, 623371, 868663, 444002, 683311,   9664,   6318, 715441,\n",
       "       749064, 891864, 977861, 560874, 910271, 475889, 209486, 540349,\n",
       "       774608, 409629, 548389, 480716, 259708, 769909,  61510, 938155,\n",
       "       407981, 252497, 714505, 116801, 400206, 346754, 277789, 313419,\n",
       "       575132, 384121, 924167, 322034, 745773, 844331,  73898, 325092,\n",
       "        94929, 812155, 470031, 903545, 227294, 670238, 899158, 223385,\n",
       "       775129, 437355, 291997, 548946, 363460, 600658, 896732, 317406,\n",
       "       331744, 363646, 211267, 603310, 766680, 880397, 777072, 648053,\n",
       "       551911, 177048, 742496, 613859, 816760, 856202,  87476, 480029,\n",
       "       964530, 803658,  22319, 248970, 410675, 398569, 576885, 947447,\n",
       "       668835, 159900, 518896, 602366, 802994,  56769, 899362, 755508,\n",
       "        60152, 622942, 924649, 701892,  12644, 984324, 925737, 345800,\n",
       "       342238, 894336, 944994, 202743, 470099, 908256,  29691, 763109,\n",
       "       625212,  46874, 266614, 258391, 424856, 392261, 789771, 833537,\n",
       "       928134, 528446, 731867, 745147, 388027, 911990, 809479, 104233,\n",
       "       894816, 309787, 363397, 213357, 331192, 107173, 906454, 349284,\n",
       "       239659, 159471, 187942, 178611, 739655, 616867, 416669, 806295,\n",
       "        70191, 933975, 570916, 847032], dtype=int32)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First row of the array loaded as `edges`.\n",
    "edges[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
