{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=0\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from utils import write_dvecs, write_ivecs, read_dvecs, read_ivecs\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hyperbolic_d6_X1.8\n"
     ]
    }
   ],
   "source": [
    "# Dataset parameters; d and X determine the directory name used by the cells below.\n",
    "n = 1010000  # presumably 1M base + 10K query points -- TODO confirm\n",
    "k = 100\n",
    "d = 6\n",
    "\n",
    "\n",
    "def _p(x):\n",
    "    # Evaluates (x^2 - 1)^(d - 2); d is read from the notebook globals at call time.\n",
    "    return (float(x) ** 2 - 1) ** (d - 2)\n",
    "\n",
    "\n",
    "# Value derived from n and d. It is kept under its own name for reference only,\n",
    "# so it is no longer silently overwritten by the hand-picked X below.\n",
    "X_from_n = (float(n) ** (1 / (d - 1))) / d\n",
    "# print(X_from_n, np.arccosh(X_from_n))\n",
    "\n",
    "# Hand-picked value that actually selects the dataset directory.\n",
    "X = 1.8\n",
    "set_path = f'hyperbolic_d{d}_X{X}'\n",
    "print(set_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.arccosh(np.array([1.01, 3, 12, 25, 100]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the base and query vectors of the dataset selected by set_path.\n",
    "data_dir = os.path.join('gbnns_theory/data', set_path)\n",
    "base = read_dvecs(os.path.join(data_dir, 'base.1M.dvecs'))\n",
    "query = read_dvecs(os.path.join(data_dir, 'query.10K.dvecs'))\n",
    "# _gt = read_ivecs(os.path.join(data_dir, 'gt.10K.ivecs'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d18e5cc3ae4848f89c47807a09c20361",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Nearest neighbours of every base point among the base points, computed in batches.\n",
    "n_neighbors = 300  # neighbours stored per base point\n",
    "batch_size = 500\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "knn = torch.zeros((len(base), n_neighbors), dtype=torch.int32)\n",
    "knn_dists = torch.zeros((len(base), n_neighbors), dtype=torch.float32)\n",
    "\n",
    "# Left operand: base with every coordinate except the first negated, so the matrix\n",
    "# product below yields x0*y0 - sum_i xi*yi for each pair of points. Smaller values\n",
    "# are treated as closer (topk is taken with largest=False).\n",
    "t_dist_base = torch.as_tensor(base).clone()\n",
    "t_dist_base[:, 1:] = -1 * t_dist_base[:, 1:]\n",
    "t_base_t = torch.as_tensor(base).t().clone().to(device)\n",
    "\n",
    "for i in tqdm(range(0, len(base), batch_size)):\n",
    "    dists = t_dist_base[i: i + batch_size].to(device) @ t_base_t\n",
    "\n",
    "    # Exclude each point from its own neighbour list explicitly. Dropping the first\n",
    "    # top-k column instead only works if a point always ranks first against itself,\n",
    "    # which duplicate points or rounding ties can break.\n",
    "    rows = torch.arange(dists.shape[0], device=dists.device)\n",
    "    dists[rows, rows + i] = float('inf')\n",
    "\n",
    "    topk = dists.topk(n_neighbors, dim=1, largest=False)\n",
    "    knn_dists[i: i + batch_size] = topk[0].cpu()\n",
    "    knn[i: i + batch_size] = topk[1].cpu()\n",
    "\n",
    "# Downstream cells expect NumPy arrays.\n",
    "knn_dists = knn_dists.numpy()\n",
    "knn = knn.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81a686a5c2d54a118bbd60806e133c9d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Ground-truth nearest neighbours of every query among the base points.\n",
    "n_gt_neighbors = 100  # neighbours stored per query\n",
    "batch_size = 500\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "gt = torch.zeros((len(query), n_gt_neighbors), dtype=torch.int32)\n",
    "\n",
    "# Queries with every coordinate except the first negated, so the matrix product\n",
    "# below yields q0*y0 - sum_i qi*yi; smaller values are treated as closer.\n",
    "dist_query = query.copy()\n",
    "dist_query[:, 1:] = -dist_query[:, 1:]\n",
    "# dist_query is already a private copy, so no further clone is needed.\n",
    "t_dist_query = torch.as_tensor(dist_query)\n",
    "# Rebuilt here so this cell can run without the cell that computes knn.\n",
    "t_base_t = torch.as_tensor(base).t().clone().to(device)\n",
    "\n",
    "for i in tqdm(range(0, len(query), batch_size)):\n",
    "    dists = t_dist_query[i: i + batch_size].to(device) @ t_base_t\n",
    "    gt[i: i + batch_size] = dists.topk(n_gt_neighbors, dim=1, largest=False)[1].cpu()\n",
    "\n",
    "# write_ivecs is called with a NumPy array below.\n",
    "gt = gt.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the ground-truth neighbour indices next to the dataset files.\n",
    "gt_path = os.path.join('gbnns_theory/data', set_path, 'gt.10K.ivecs')\n",
    "write_ivecs(gt_path, gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the kNN graph under the model directory, creating the directory if needed.\n",
    "models_dir = os.path.join('gbnns_theory/models', set_path)\n",
    "os.makedirs(models_dir, exist_ok=True)\n",
    "write_ivecs(os.path.join(models_dir, 'knn.ivecs'), knn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
