{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ecf78deb-5012-4ffe-b925-564f5fda7be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch\n",
    "import numpy as np\n",
    "import json\n",
    "from pyemd import emd\n",
    "from tqdm import tqdm\n",
    "from multiprocessing import Pool\n",
    "\n",
    "from sklearn.metrics import euclidean_distances\n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "from wmd import WordMoversDistance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9f5879fd-4ae8-4468-bddf-578c72d10bbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer and the model from a single checkpoint name so the two always match.\n",
    "MODEL_NAME = \"microsoft/codebert-base\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModel.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# Weights of the model's input word-embedding layer, detached and converted to a NumPy array.\n",
    "embedding_weight = model.embeddings.word_embeddings.weight\n",
    "embeddings = embedding_weight.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "84b06557-8545-4088-bc8d-1faae62dc252",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distance scorer built on the embedding matrix extracted above.\n",
    "# n_jobs and verbose are handed on to joblib.Parallel inside wmd.py.\n",
    "wmd = WordMoversDistance(\n",
    "    embeddings,\n",
    "    n_jobs=-1,\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d8e867e4-4d42-4a82-b428-24043e0bca15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(code):\n",
    "    \"\"\"Turn a source-code string into a list of tokenizer vocabulary ids.\n",
    "\n",
    "    Splits `code` with the notebook-level `tokenizer`, drops every token that is\n",
    "    exactly 'Ċ' or 'Ġ' (presumably the bare newline / space markers -- confirm\n",
    "    against the tokenizer's vocabulary), and maps the remaining tokens to ids.\n",
    "    \"\"\"\n",
    "    kept = []\n",
    "    for piece in tokenizer.tokenize(code):\n",
    "        if piece in ('Ċ', 'Ġ'):\n",
    "            continue\n",
    "        kept.append(piece)\n",
    "    return tokenizer.convert_tokens_to_ids(kept)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8cee15ac-51af-477b-a180-96c7ab4ef8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Java and Python test sets; each file is parsed as one JSON document.\n",
    "# The encoding is given explicitly because open() otherwise falls back to the\n",
    "# platform's locale encoding, which is not guaranteed to be UTF-8.\n",
    "with open('../data/detok-tc-test-data/java.json', 'r', encoding='utf-8') as f:\n",
    "    data_java = json.load(f)\n",
    "\n",
    "with open('../data/detok-tc-test-data/python.json', 'r', encoding='utf-8') as f:\n",
    "    data_py = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d3e89f33-e08a-44c6-aaf1-85584afcd1e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split each loaded mapping into two aligned tuples: its keys (ids) and its values (code).\n",
    "(ids_java, code_java) = zip(*data_java.items())\n",
    "(ids_py, code_py) = zip(*data_py.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2898fef6-511e-4067-a545-2a49f28699fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (536 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    }
   ],
   "source": [
    "# Token-id lists, in the same order as code_py / code_java.\n",
    "tokens_py = list(map(tokenize, code_py))\n",
    "tokens_java = list(map(tokenize, code_java))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "46b8c9e0-d0ab-4bdb-a46d-f6428df26e87",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(868, 868)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: number of tokenized snippets per language, as (python, java).\n",
    "tuple(map(len, (tokens_py, tokens_java)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3cb7a253-6bc8-497f-9403-709e99c8d217",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the Java snippets (ids plus token-id lists) as the collection that later\n",
    "# wmd.predict() calls are compared against.\n",
    "# NOTE(review): presumably these ids are what wmd.source_ids exposes afterwards -- confirm in wmd.py.\n",
    "wmd.fit(ids_java, tokens_java)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "438f6af9-3fc7-4416-861e-1fddcfc43730",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "40bde4b8863c4a318b249421732e2c38",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/868 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-edab0bf3aae6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mtoks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokens_py\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m     \u001b[0mdists\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwmd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtoks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m     coderesult = {\n\u001b[1;32m      9\u001b[0m         \u001b[0mjavaid\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdist\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mjavaid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwmd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msource_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdists\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/projects/wmd-codesim/experiments/wmd.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, target_tokens)\u001b[0m\n\u001b[1;32m     70\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     71\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m         \u001b[0mdist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pairwise_wmd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     73\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/projects/wmd-codesim/experiments/wmd.py\u001b[0m in \u001b[0;36m_pairwise_wmd\u001b[0;34m(self, target_tokens)\u001b[0m\n\u001b[1;32m     60\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_pairwise_wmd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m         dist = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(\n\u001b[0m\u001b[1;32m     63\u001b[0m             \u001b[0mdelayed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_wmd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     64\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/wmd/lib/python3.8/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1052\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1053\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieval_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1054\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1055\u001b[0m             \u001b[0;31m# Make sure that we get a last message telling us we are done\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1056\u001b[0m             \u001b[0melapsed_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_start_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/wmd/lib/python3.8/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mretrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    931\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    932\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'supports_timeout'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 933\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjob\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    934\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    935\u001b[0m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjob\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/wmd/lib/python3.8/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mwrap_future_result\u001b[0;34m(future, timeout)\u001b[0m\n\u001b[1;32m    540\u001b[0m         AsyncResults.get from multiprocessing.\"\"\"\n\u001b[1;32m    541\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 542\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfuture\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    543\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mCfTimeoutError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    544\u001b[0m             \u001b[0;32mraise\u001b[0m \u001b[0mTimeoutError\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/wmd/lib/python3.8/concurrent/futures/_base.py\u001b[0m in \u001b[0;36mresult\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    432\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__get_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    433\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 434\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_condition\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    435\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    436\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mCANCELLED\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCANCELLED_AND_NOTIFIED\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/wmd/lib/python3.8/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    300\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m    \u001b[0;31m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    301\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m                 \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    303\u001b[0m                 \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    304\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# result[python_id][java_id] -> distance reported by wmd.predict for that pair.\n",
    "result = {}\n",
    "\n",
    "# Iterate ids and token lists together instead of by index; the loop variable is\n",
    "# deliberately not called `id`, which would shadow the builtin for every later cell.\n",
    "for py_id, toks in tqdm(zip(ids_py, tokens_py), total=len(ids_py)):\n",
    "    # One distance per fitted Java snippet, paired below with wmd.source_ids.\n",
    "    dists = wmd.predict(toks)\n",
    "    result[py_id] = dict(zip(wmd.source_ids, dists))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2942bb5e-b570-4aa8-a4f1-b1526e9900ca",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Write the nested result mapping to disk as JSON.\n",
    "results_path = '../results/TC-test-set/python-java-codebert-parallel.json'\n",
    "with open(results_path, 'w') as f:\n",
    "    json.dump(result, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c000815-7730-41ab-9fe3-6215efe5c9dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-1 retrieval accuracy: a Python snippet counts as correct when the Java snippet\n",
    "# with the smallest distance carries the same id.\n",
    "corr, tot = 0, 0\n",
    "\n",
    "for py_id, java_dists in result.items():\n",
    "    # min() picks the first lowest-distance id, the same entry a stable sort would\n",
    "    # put first, without sorting the whole candidate list.\n",
    "    nearest_java_id = min(java_dists, key=java_dists.get)\n",
    "\n",
    "    if py_id == nearest_java_id:\n",
    "        corr += 1\n",
    "    tot += 1\n",
    "\n",
    "# An empty `result` (e.g. after an interrupted run) would otherwise divide by zero.\n",
    "accuracy = corr / float(tot) if tot else float('nan')\n",
    "print(accuracy, corr, tot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d6c3876-a852-414c-af2d-699454a725dc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "253c1673-3b93-4a90-ac4c-3d86da343cf7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a404f95-12ec-4bab-bd45-ade63f0977f6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
