{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "683d911d-4676-4aca-81a8-978dab710951",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/mayank/anaconda3/envs/wmd/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "from gensim import corpora, models, similarities\n",
    "\n",
    "import numpy as np\n",
    "import json\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "48e272d6-af4f-460a-a351-ffa21717f120",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('microsoft/graphcodebert-base')\n",
    "\n",
    "lang1 = 'java'\n",
    "lang2 = 'csharp'\n",
    "toktype = 'bert'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be760f65-f104-473a-ab88-43d5e14ec6ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_bert(code):\n",
    "    RM = ['Ċ', 'Ġ']\n",
    "    tokens = [x for x in tokenizer.tokenize(code) if x not in RM]\n",
    "    return tokens\n",
    "\n",
    "def tokenize_simple(code):\n",
    "    return [x for x in code.lower().split(' ') if len(x) > 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f0883d5f-89d6-4b87-b7af-d808f27d9fed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_TC_java_py_data():\n",
    "    with open('../data/detok-tc-test-data/java.json', 'r') as f:\n",
    "        javacodes = json.load(f)\n",
    "\n",
    "    with open('../data/detok-tc-test-data/python.json', 'r') as f:\n",
    "        pycodes = json.load(f)\n",
    "        \n",
    "    return javacodes, pycodes\n",
    "\n",
    "\n",
    "def get_TC_java_cpp_data():\n",
    "    with open('../data/detok-tc-test-data/java.json', 'r') as f:\n",
    "        javacodes = json.load(f)\n",
    "\n",
    "    with open('../data/detok-tc-test-data/cpp.json', 'r') as f:\n",
    "        cppcodes = json.load(f)\n",
    "        \n",
    "    return javacodes, cppcodes\n",
    "\n",
    "\n",
    "def get_TC_python_cpp_data():\n",
    "    with open('../data/detok-tc-test-data/python.json', 'r') as f:\n",
    "        pycodes = json.load(f)\n",
    "\n",
    "    with open('../data/detok-tc-test-data/cpp.json', 'r') as f:\n",
    "        cppcodes = json.load(f)\n",
    "        \n",
    "    return pycodes, cppcodes\n",
    "\n",
    "\n",
    "def get_java_csharp_data():\n",
    "    \n",
    "    with open('../data/code-translation/java-C#/data/train.java-cs.txt.java', 'r') as f:\n",
    "        javacodes = {i: line for i, line in enumerate(f.readlines())}\n",
    "        \n",
    "    with open('../data/code-translation/java-C#/data/train.java-cs.txt.cs', 'r') as f:\n",
    "        cscodes = {i: line for i, line in enumerate(f.readlines())}\n",
    "        \n",
    "    return javacodes, cscodes\n",
    "\n",
    "\n",
    "def get_data(data1, data2):\n",
    "    if data1 == 'java' and data2 == 'python':\n",
    "        return get_TC_java_py_data()\n",
    "    \n",
    "    elif data1 == 'java' and data2 == 'csharp':\n",
    "        return get_java_csharp_data()\n",
    "    \n",
    "    elif data1 == 'java' and data2 == 'cpp':\n",
    "        return get_TC_java_cpp_data()\n",
    "    \n",
    "    elif data1 == 'python' and data2 == 'cpp':\n",
    "        return get_TC_python_cpp_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d98d03b3-f2db-4958-b09d-c500e10f8fc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "code1, code2 = get_data(lang1, lang2)\n",
    "\n",
    "code1_keys = set(code1.keys())\n",
    "code2_keys = set(code2.keys())\n",
    "\n",
    "assert len(code1_keys.difference(code2_keys)) == 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2ff50367-f60a-4cf6-81b7-4130c0f7701b",
   "metadata": {},
   "outputs": [],
   "source": [
    "order = sorted(code1_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cb4a70e3-654c-4bab-87f8-88fcc82329c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (540 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using tokenizer <function tokenize_bert at 0x7fe701548670>\n"
     ]
    }
   ],
   "source": [
    "tokenize = tokenize_simple if toktype == 'simple' else tokenize_bert\n",
    "\n",
    "print(f'Using tokenizer {tokenize}')\n",
    "\n",
    "code1_tokenized_corpus = [tokenize(code1[key]) for key in order]\n",
    "code2_tokenized_corpus = [tokenize(code2[key]) for key in order]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "11109688-9141-4605-8f23-b2f20035b948",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_text = code1_tokenized_corpus + code2_tokenized_corpus\n",
    "dictionary = corpora.Dictionary(all_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ec0eec4c-fcc9-45ad-8c4e-5c6620fa9759",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2265\n"
     ]
    }
   ],
   "source": [
    "feature_cnt = len(dictionary.token2id)\n",
    "print(feature_cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2ac511d9-d17b-488c-8234-95aa3837ab41",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [dictionary.doc2bow(code) for code in code1_tokenized_corpus]\n",
    "# lda = models.ldamodel.LdaModel(corpus, id2word=dictionary) \n",
    "lda = models.ldamulticore.LdaMulticore(corpus, id2word=dictionary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1a26797-4da1-47db-8f04-88a199d28287",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c863f6766d064acb95541172e3f19925",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1418 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "corr, total = 0, 0\n",
    "\n",
    "\n",
    "with tqdm(enumerate(code2_tokenized_corpus), total=len(order)) as pbar:\n",
    "    for i, code in pbar:\n",
    "        kw_vector = dictionary.doc2bow(code)\n",
    "        index = similarities.MatrixSimilarity(lda[corpus])\n",
    "        sim = index[lda[kw_vector]]\n",
    "        \n",
    "        matching_idx = np.argmax(sim)\n",
    "        assert max(sim) == sim[matching_idx]\n",
    "        \n",
    "        if i == matching_idx:\n",
    "            corr += 1\n",
    "        total += 1\n",
    "        \n",
    "        acc = (corr / float(total)) * 100.0\n",
    "        pbar.set_description(f'Accuracy: {acc:0.3f}')\n",
    "    \n",
    "acc = corr / float(total)\n",
    "print(f'Accuracy: {acc * 100.0}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fecc6f39-a11f-4210-81eb-135225242800",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
