{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import numpy as np\n",
    "from utils import *\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = 'poincare_glove_models/poincare_glove_100D_cosh-dist-sq.txt'\n",
    "with open(filename, \"r\") as f:\n",
    "    model_name = f.readline()\n",
    "    \n",
    "    p_glove = []\n",
    "    for i in tqdm(range(189533)):\n",
    "        token, *emb = f.readline().split(' ')[:-1]\n",
    "        emb = [float(val) for val in emb]\n",
    "        p_glove.append(emb)\n",
    "\n",
    "p_glove = np.array(p_glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = 'poincare_glove_models/vanilla_glove_100D.txt'\n",
    "with open(filename, \"r\") as f:\n",
    "    model_name = f.readline()\n",
    "    glove = []\n",
    "    for i in tqdm(range(189533)):\n",
    "        token, *emb = f.readline().split(' ')[:-1]\n",
    "        emb = [float(val) for val in emb]\n",
    "        glove.append(emb)\n",
    "        \n",
    "glove = np.array(glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idxs = np.random.permutation(len(glove))\n",
    "base_idxs = idxs[-9533:]\n",
    "query_idxs = idxs[:9533]\n",
    "\n",
    "base_glove = glove[base_idxs]\n",
    "base_p_glove = p_glove[base_idxs]\n",
    "\n",
    "query_glove = glove[query_idxs]\n",
    "query_p_glove = p_glove[query_idxs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 100\n",
    "dist_query_p_glove = query_p_glove.copy()\n",
    "dist_query_p_glove[1:] = -dist_query_p_glove[1:]\n",
    "\n",
    "dists = dist_query_p_glove @ base_p_glove.transpose()\n",
    "gt_p_glove = np.zeros((len(query_p_glove), k)).astype('int64')\n",
    "for i in tqdm(range(len(query_p_glove))):\n",
    "    gt_p_glove[i] = np.argsort(dists[i])[:k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dot = query_p_glove @ base_p_glove.transpose()\n",
    "sq_norm_q = (query_p_glove ** 2).sum(-1)\n",
    "sq_norm_b = (base_p_glove ** 2).sum(-1)\n",
    "sq_dists = sq_norm_q[:, None] - 2 * dot + sq_norm_b[None]\n",
    "dists = sq_dists / (1 - sq_norm_q[:, None]) / (1 - sq_norm_b[None]) \n",
    "\n",
    "gt_p_glove = np.zeros((len(query_p_glove), k)).astype('int64')\n",
    "for i in tqdm(range(len(query_p_glove))):\n",
    "    gt_p_glove[i] = np.argsort(dists[i])[:k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "write_fvecs('poincare_glove_base.fvecs', base_p_glove)\n",
    "write_fvecs('poincare_glove_query.fvecs', query_p_glove)\n",
    "write_ivecs('poincare_glove_groundtruth.ivecs', gt_p_glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 100\n",
    "dists = - query_glove @ base_glove.transpose()\n",
    "gt_glove = np.zeros((len(query_glove), k)).astype('int64')\n",
    "for i in tqdm(range(len(query_glove))):\n",
    "    gt_glove[i] = np.argsort(dists[i])[:k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "write_fvecs('glove_base.fvecs', base_glove)\n",
    "write_fvecs('glove_query.fvecs', query_glove)\n",
    "write_ivecs('glove_groundtruth.ivecs', gt_glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
