{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "683d911d-4676-4aca-81a8-978dab710951",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rank_bm25 import BM25Okapi\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "48e272d6-af4f-460a-a351-ffa21717f120",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subword tokenizer used by tokenize_bert when toktype == 'bert'.\n",
    "tokenizer = AutoTokenizer.from_pretrained('microsoft/graphcodebert-base')\n",
    "\n",
    "# lang1 is indexed by BM25; lang2 snippets are the retrieval queries.\n",
    "lang1 = 'java'\n",
    "lang2 = 'csharp'\n",
    "# 'bert' (subword tokens) or 'simple' (lowercased whitespace split).\n",
    "toktype = 'bert'\n",
    "# Derived from the language pair so switching lang1/lang2 does not overwrite\n",
    "# another pair's results; resolves to './java-csharp/bm25' for the pair above.\n",
    "result_folder = f'./{lang1}-{lang2}/bm25'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be760f65-f104-473a-ab88-43d5e14ec6ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Characters GraphCodeBERT's byte-level BPE uses for the whitespace bytes\n",
    "# tab, newline, carriage return and space. A token made up only of these\n",
    "# (e.g. 'Ġ', 'Ċ', 'ĠĠĠ', 'ĊĊ') carries no lexical content.\n",
    "BPE_WHITESPACE_CHARS = 'ĉĊčĠ'\n",
    "\n",
    "\n",
    "def tokenize_bert(code):\n",
    "    \"\"\"Split `code` with the global `tokenizer`, dropping whitespace-only tokens.\n",
    "\n",
    "    Tokens that merely carry a whitespace prefix (e.g. 'Ġint') are kept as-is.\n",
    "    \"\"\"\n",
    "    return [tok for tok in tokenizer.tokenize(code) if tok.strip(BPE_WHITESPACE_CHARS)]\n",
    "\n",
    "\n",
    "def tokenize_simple(code):\n",
    "    \"\"\"Lowercase `code` and split it on any run of whitespace.\n",
    "\n",
    "    Unlike split(' '), this also separates on tabs and newlines, so a line's\n",
    "    trailing newline is not glued onto its last token, and it never yields\n",
    "    empty strings.\n",
    "    \"\"\"\n",
    "    return code.lower().split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f0883d5f-89d6-4b87-b7af-d808f27d9fed",
   "metadata": {},
   "outputs": [],
   "source": [
    "TC_DATA_DIR = '../data/detok-tc-test-data'\n",
    "JAVA_CS_DATA_DIR = '../data/code-translation/java-C#/data'\n",
    "\n",
    "\n",
    "def _load_json(path):\n",
    "    \"\"\"Read a JSON file as UTF-8 rather than the platform's default encoding.\"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "def _load_lines(path):\n",
    "    \"\"\"Map each 0-based line number of a UTF-8 text file to that line (newline kept).\"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        return dict(enumerate(f))\n",
    "\n",
    "\n",
    "def get_TC_java_py_data():\n",
    "    \"\"\"Return the (java, python) snippets from the detok-tc-test-data JSON files.\"\"\"\n",
    "    return (_load_json(os.path.join(TC_DATA_DIR, 'java.json')),\n",
    "            _load_json(os.path.join(TC_DATA_DIR, 'python.json')))\n",
    "\n",
    "\n",
    "def get_TC_java_cpp_data():\n",
    "    \"\"\"Return the (java, cpp) snippets from the detok-tc-test-data JSON files.\"\"\"\n",
    "    return (_load_json(os.path.join(TC_DATA_DIR, 'java.json')),\n",
    "            _load_json(os.path.join(TC_DATA_DIR, 'cpp.json')))\n",
    "\n",
    "\n",
    "def get_TC_python_cpp_data():\n",
    "    \"\"\"Return the (python, cpp) snippets from the detok-tc-test-data JSON files.\"\"\"\n",
    "    return (_load_json(os.path.join(TC_DATA_DIR, 'python.json')),\n",
    "            _load_json(os.path.join(TC_DATA_DIR, 'cpp.json')))\n",
    "\n",
    "\n",
    "def get_java_csharp_data():\n",
    "    \"\"\"Return (java, csharp) dicts keyed by 0-based line number of the train files.\n",
    "\n",
    "    The two files are assumed to be line-aligned (line i of one is the\n",
    "    counterpart of line i of the other).\n",
    "    \"\"\"\n",
    "    return (_load_lines(os.path.join(JAVA_CS_DATA_DIR, 'train.java-cs.txt.java')),\n",
    "            _load_lines(os.path.join(JAVA_CS_DATA_DIR, 'train.java-cs.txt.cs')))\n",
    "\n",
    "\n",
    "def get_data(data1, data2):\n",
    "    \"\"\"Load the paired corpora for the language pair (data1, data2).\n",
    "\n",
    "    Returns a tuple (code1, code2) of key -> code-snippet mappings.\n",
    "    Raises ValueError for an unsupported pair.\n",
    "    \"\"\"\n",
    "    loaders = {\n",
    "        ('java', 'python'): get_TC_java_py_data,\n",
    "        ('java', 'csharp'): get_java_csharp_data,\n",
    "        ('java', 'cpp'): get_TC_java_cpp_data,\n",
    "        ('python', 'cpp'): get_TC_python_cpp_data,\n",
    "    }\n",
    "    if (data1, data2) not in loaders:\n",
    "        raise ValueError(f'Unsupported language pair ({data1!r}, {data2!r}); '\n",
    "                         f'supported pairs: {sorted(loaders)}')\n",
    "    return loaders[(data1, data2)]()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d98d03b3-f2db-4958-b09d-c500e10f8fc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "code1, code2 = get_data(lang1, lang2)\n",
    "\n",
    "code1_keys = set(code1.keys())\n",
    "code2_keys = set(code2.keys())\n",
    "\n",
    "# Every lang1 sample is later looked up in code2 under the same key.\n",
    "missing_in_code2 = code1_keys - code2_keys\n",
    "assert not missing_in_code2, (\n",
    "    f'{len(missing_in_code2)} {lang1} samples have no {lang2} counterpart')\n",
    "\n",
    "# Samples present only on the lang2 side are never evaluated; make that visible.\n",
    "extra_in_code2 = code2_keys - code1_keys\n",
    "if extra_in_code2:\n",
    "    print(f'Ignoring {len(extra_in_code2)} {lang2} samples with no {lang1} counterpart')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2ff50367-f60a-4cf6-81b7-4130c0f7701b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deterministic evaluation order; position i in both tokenized corpora refers to key order[i].\n",
    "order = list(code1_keys)\n",
    "order.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb4a70e3-654c-4bab-87f8-88fcc82329c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit lookup so a mistyped toktype fails loudly instead of silently using BERT.\n",
    "tokenize_fns = {'simple': tokenize_simple, 'bert': tokenize_bert}\n",
    "if toktype not in tokenize_fns:\n",
    "    raise ValueError(f'Unknown toktype {toktype!r}; expected one of {sorted(tokenize_fns)}')\n",
    "tokenize = tokenize_fns[toktype]\n",
    "\n",
    "print(f'Using tokenizer {tokenize.__name__}')\n",
    "\n",
    "# Tokenize both sides in the same key order so index i is the same sample in each.\n",
    "code1_tokenized_corpus = [tokenize(code1[key]) for key in order]\n",
    "code2_tokenized_corpus = [tokenize(code2[key]) for key in order]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9e76f713-a8c8-4f2c-9459-769545513632",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BM25 index over the lang1 corpus; each lang2 snippet is scored against it as a query.\n",
    "bm25 = BM25Okapi(corpus=code1_tokenized_corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dede2ad-3efd-40e3-b11f-a81337055e3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr, total = 0, 0\n",
    "code1_list, code2_list = [], []\n",
    "\n",
    "# For every lang2 snippet, retrieve the best-scoring lang1 snippet. A retrieval is\n",
    "# correct when it returns the query's own position (both corpora follow `order`).\n",
    "# NOTE: bm25.get_scores loops over the whole index in Python for every query token,\n",
    "# so this cell is slow on large corpora.\n",
    "with tqdm(enumerate(code2_tokenized_corpus), total=len(order)) as pbar:\n",
    "    for i, query in pbar:\n",
    "        code2_list.append(code2[order[i]])\n",
    "\n",
    "        # One score per indexed document; np.argmax picks the first maximum on ties.\n",
    "        scores = bm25.get_scores(query)\n",
    "        matching_idx = int(np.argmax(scores))\n",
    "        code1_list.append(code1[order[matching_idx]])\n",
    "\n",
    "        corr += int(matching_idx == i)\n",
    "        total += 1\n",
    "        pbar.set_description(f'Accuracy: {corr / total * 100.0:0.3f}')\n",
    "\n",
    "# `acc` is a fraction (the export cell multiplies it by 100). It is assigned only\n",
    "# here, after the loop, so an interrupted run cannot leave a percentage in `acc`.\n",
    "acc = corr / total if total else float('nan')\n",
    "print(f'Accuracy: {acc * 100.0}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b506de3b-2a46-42a2-99d9-65b5f57d9d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw counts behind the accuracy: (correct retrievals, queries evaluated)\n",
    "(corr, total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fecc6f39-a11f-4210-81eb-135225242800",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_one_per_line(path, snippets):\n",
    "    \"\"\"Write `snippets` to `path` as UTF-8, one snippet per line.\n",
    "\n",
    "    Snippets read with readlines() already end in a newline, while JSON-loaded\n",
    "    ones may not; appending a newline only where it is missing keeps the output\n",
    "    files line-aligned for both data sources.\n",
    "    NOTE(review): a snippet containing internal newlines still spans several lines.\n",
    "    \"\"\"\n",
    "    with open(path, 'w', encoding='utf-8') as f:\n",
    "        for snippet in snippets:\n",
    "            f.write(snippet if snippet.endswith('\\n') else snippet + '\\n')\n",
    "\n",
    "\n",
    "# Create the output directory on a fresh checkout instead of failing in open().\n",
    "os.makedirs(result_folder, exist_ok=True)\n",
    "\n",
    "write_one_per_line(os.path.join(result_folder, f'{lang1}.txt'), code1_list)\n",
    "write_one_per_line(os.path.join(result_folder, f'{lang2}.txt'), code2_list)\n",
    "\n",
    "with open(os.path.join(result_folder, 'acc.txt'), 'w', encoding='utf-8') as f:\n",
    "    f.write(f'Accuracy: {acc * 100.0}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7fa86be-29e4-4d67-910e-0eed36807361",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
