{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "93a0da81-b6d0-46c5-a693-f76c1e220f2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "logger = logging.getLogger('WMD')\n",
    "logger.addHandler(logging.StreamHandler(sys.stdout))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4355fb5a-647d-4391-8a45-fb736f21f21d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from wmd import WMD\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch\n",
    "import json\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from sklearn.metrics import euclidean_distances\n",
    "from sklearn.preprocessing import normalize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5c2b7752-69d6-4674-80ea-e15c7f94e7ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaModel were not initialized from the model checkpoint at microsoft/graphcodebert-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/graphcodebert-base\")\n",
    "model = AutoModel.from_pretrained(\"microsoft/graphcodebert-base\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "34e311a6-1114-4c18-b373-edbe32e9d619",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1418 1418\n"
     ]
    }
   ],
   "source": [
    "with open('../data/detok-tc-test-data/java.json', 'r') as f:\n",
    "    javacodes = json.load(f)\n",
    "    \n",
    "with open('../data/detok-tc-test-data/python.json', 'r') as f:\n",
    "    pycodes = json.load(f)\n",
    "    \n",
    "print(len(javacodes), len(pycodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fed5a8f0-13e1-4003-a83a-1dbe649ba897",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50265, 768)\n"
     ]
    }
   ],
   "source": [
    "embeddings = model.embeddings.word_embeddings.weight.detach().numpy().astype(np.float32)\n",
    "print(embeddings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9b9119ab-77fa-470e-8e33-f12e9ba69603",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hist(code):\n",
    "    \n",
    "    RM = ['Ċ', 'Ġ']\n",
    "    tokens = [x for x in tokenizer.tokenize(code) if x not in RM]\n",
    "    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    \n",
    "    token_counts = {}\n",
    "    for id in token_ids:\n",
    "        token_counts[id] = token_counts.get(id, 0) + 1\n",
    "        \n",
    "    idxs = sorted(token_counts.keys())\n",
    "    weights = np.array([token_counts[x] for x in idxs], dtype=np.float32)\n",
    "    weights = weights / np.sum(weights)\n",
    "    return idxs, weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "53ea5c55-28e4-40bf-be1f-4a4933e0b1b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (693 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    }
   ],
   "source": [
    "nbow_java = {}\n",
    "\n",
    "for javaid, javacode in javacodes.items():\n",
    "    idxs, weights = get_hist(javacode)\n",
    "    nbow_java[javaid] = (javaid, idxs, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4ff20f67-ae38-419e-a1ed-25bf0dc0a763",
   "metadata": {},
   "outputs": [],
   "source": [
    "nbow_py = {}\n",
    "\n",
    "for pyid, pycode in pycodes.items():\n",
    "    idxs, weights = get_hist(pycode)\n",
    "    nbow_py[pyid] = (pyid, idxs, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cd347046-22f3-4f1d-ac07-b2c8797aa167",
   "metadata": {},
   "outputs": [],
   "source": [
    "calc = WMD(embeddings, nbow_java, \n",
    "           vocabulary_min=1, vocabulary_max=500, \n",
    "           vocabulary_optimizer=None,\n",
    "          verbosity=logging.WARNING)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "be755dc9-71bd-4319-ae8b-a5055f303a1e",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "EOL while scanning string literal (<ipython-input-10-a218de818a39>, line 13)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  File \u001b[0;32m\"<ipython-input-10-a218de818a39>\"\u001b[0;36m, line \u001b[0;32m13\u001b[0m\n\u001b[0;31m    pbar.set_description(f'Accuracy: {acc*100:0.3f}'')\u001b[0m\n\u001b[0m                                                      ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
     ]
    }
   ],
   "source": [
    "corr, total = 0, 0\n",
    "\n",
    "with tqdm(nbow_py.keys()) as pbar:\n",
    "    for key in pbar:\n",
    "        _, words, weights = nbow_py[key]\n",
    "        res = calc.nearest_neighbors((words, weights))\n",
    "\n",
    "        if res[0][0] == key:\n",
    "            corr += 1\n",
    "        total += 1\n",
    "        \n",
    "        acc = corr / float(total)\n",
    "        pbar.set_description(f'Accuracy: {acc*100:0.3f}')\n",
    "    \n",
    "print(corr, total, corr/float(total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d946e9bf-377c-4e7c-928c-de60bb033027",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
