{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3bd7e6c4-34c8-4d60-9a4e-659b07ad53aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "from wmd import WMD\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch\n",
    "import json\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "310e74e3-7aa7-419f-8510-e201c9759408",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaModel were not initialized from the model checkpoint at microsoft/graphcodebert-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# GraphCodeBERT checkpoint: the tokenizer feeds get_hist, the model supplies the word-embedding table.\n",
    "MODEL_NAME = 'microsoft/graphcodebert-base'\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModel.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c5f1c2e8-3e0d-4fa1-bb6c-2b46d3f7359b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50265, 768)\n"
     ]
    }
   ],
   "source": [
    "# Token-embedding matrix (one row per vocabulary id), copied out of torch as float32 for WMD.\n",
    "token_embedding_weight = model.embeddings.word_embeddings.weight\n",
    "embeddings = token_embedding_weight.detach().numpy().astype(np.float32)\n",
    "print(embeddings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ae937065-86f7-4f8f-8b28-442d7d97b219",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10300 10300\n"
     ]
    }
   ],
   "source": [
    "# The notebook assumes line i of the .java file is the translation of line i of the .cs file.\n",
    "path = \"../data/code-translation/java-C#/data/\"\n",
    "\n",
    "# Decode explicitly as UTF-8 so the result does not depend on the platform's locale encoding.\n",
    "with open(path + 'train.java-cs.txt.java', 'r', encoding='utf-8') as f:\n",
    "    java_lines = f.readlines()\n",
    "\n",
    "with open(path + 'train.java-cs.txt.cs', 'r', encoding='utf-8') as f:\n",
    "    cs_lines = f.readlines()\n",
    "\n",
    "print(len(java_lines), len(cs_lines))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d4ff9635-1ecb-42ae-a80c-d35033fe855e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The corpora are paired by line number, so they must be the same length\n",
    "# (zip would otherwise silently drop the unmatched tail).\n",
    "assert len(java_lines) == len(cs_lines), 'java/cs files are not line-aligned'\n",
    "\n",
    "java_codes = {}\n",
    "cs_codes = {}\n",
    "\n",
    "for i, (java_line, cs_line) in enumerate(zip(java_lines, cs_lines)):\n",
    "    code_id = f'Code-{i}'  # shared key for the parallel pair\n",
    "    java_codes[code_id] = java_line\n",
    "    cs_codes[code_id] = cs_line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a6b25295-b63a-4f90-8db8-cb4abf7e23c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hist(code):\n",
    "    \"\"\"Return the normalized token histogram of `code` in the nBOW form WMD expects.\n",
    "\n",
    "    Returns (idxs, weights): `idxs` is a sorted list of distinct token ids and\n",
    "    `weights` is a float32 array of their relative frequencies (summing to 1).\n",
    "    Whitespace-only byte-level BPE tokens are dropped; if nothing remains,\n",
    "    both results are empty.\n",
    "    \"\"\"\n",
    "    # Byte-level BPE renders space/newline/tab/CR as Ġ/Ċ/ĉ/č; a token made only of\n",
    "    # these markers (e.g. 'Ġ', 'Ċ', 'ĠĠ') carries no code content.\n",
    "    whitespace_markers = 'ĠĊĉč'\n",
    "    tokens = [t for t in tokenizer.tokenize(code) if t.strip(whitespace_markers)]\n",
    "    if not tokens:\n",
    "        return [], np.zeros(0, dtype=np.float32)\n",
    "\n",
    "    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    # np.unique returns the distinct ids already sorted, together with their counts.\n",
    "    unique_ids, counts = np.unique(token_ids, return_counts=True)\n",
    "    weights = counts.astype(np.float32)\n",
    "    weights /= weights.sum()  # in-place division keeps the float32 dtype\n",
    "    return unique_ids.tolist(), weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0c8c84ee-d0cb-4756-8b9d-2156a2d1d36e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Java side: id -> (name, token ids, weights); this mapping is handed to WMD as its corpus.\n",
    "nbow_java = {\n",
    "    code_id: (code_id, *get_hist(code))\n",
    "    for code_id, code in java_codes.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2d54eed4-2602-40f8-93d6-c3afd9a3414f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# C# side: same (name, token ids, weights) layout; these entries serve as queries.\n",
    "nbow_cs = {\n",
    "    code_id: (code_id, *get_hist(code))\n",
    "    for code_id, code in cs_codes.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0acbe192-aa5c-43da-860d-c17f0a308729",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WMD search index over the Java snippets.\n",
    "# NOTE(review): see the wmd-relax docs for how vocabulary_min/vocabulary_max and a\n",
    "# None vocabulary_optimizer affect which documents and words are kept.\n",
    "calc = WMD(\n",
    "    embeddings,\n",
    "    nbow_java,\n",
    "    vocabulary_min=1,\n",
    "    vocabulary_max=500,\n",
    "    vocabulary_optimizer=None,\n",
    "    verbosity=logging.WARNING,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7d727e8f-ddf5-4c2a-a65c-5b2ccac024c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieval accuracy: a C# query counts as correct when its nearest Java\n",
    "# neighbour under WMD is the snippet with the same id (its parallel translation).\n",
    "# The exhaustive search is slow; set N_QUERIES to an int for a quick partial run.\n",
    "N_QUERIES = None\n",
    "\n",
    "corr, total = 0, 0\n",
    "\n",
    "pbar = tqdm(list(nbow_cs)[:N_QUERIES])\n",
    "\n",
    "for key in pbar:\n",
    "    _, words, weights = nbow_cs[key]\n",
    "    total += 1\n",
    "    if len(words) == 0:\n",
    "        # An empty snippet has nothing to compare; count it as a miss.\n",
    "        continue\n",
    "\n",
    "    res = calc.nearest_neighbors((words, weights))\n",
    "    # res holds (name, distance) pairs, closest first; guard against an empty result.\n",
    "    if res and res[0][0] == key:\n",
    "        corr += 1\n",
    "\n",
    "    acc = corr / float(total)\n",
    "    pbar.set_description(f'Acc: {acc:0.3f}')\n",
    "\n",
    "acc = corr / float(total) if total else 0.0\n",
    "print(corr, total, acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b2f9f77b-8c7b-44d2-8dbd-56b56f91f895",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Code-1720', 0.8467683792114258),\n",
       " ('Code-7795', 1.738513708114624),\n",
       " ('Code-4794', 1.921369194984436),\n",
       " ('Code-8870', 1.9376301765441895),\n",
       " ('Code-5624', 1.9744867086410522),\n",
       " ('Code-8576', 1.984363317489624),\n",
       " ('Code-8716', 2.0754191875457764),\n",
       " ('Code-9730', 2.077479839324951),\n",
       " ('Code-2124', 2.0802154541015625),\n",
       " ('Code-436', 2.0819151401519775)]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Neighbour list from the most recent iteration of the evaluation loop above.\n",
    "# NOTE(review): hidden state - it reflects whichever query ran last (e.g. after an interrupt).\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53a85064-07e8-4f7c-83fe-0913c22f58fb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
