#!/usr/bin/env python
# coding: utf-8

# In[1]:


from rank_bm25 import BM25Okapi

import os, sys
import numpy as np
import json
from tqdm.notebook import tqdm

from transformers import AutoTokenizer


# In[2]:


# Positional CLI arguments, e.g.:  java csharp bert ./java-csharp/bm25
# They are checked before the tokenizer is loaded, so a bad invocation fails
# fast with a usage message instead of an IndexError after a potentially slow
# model load.
if len(sys.argv) < 5:
    sys.exit(f'usage: {sys.argv[0]} LANG1 LANG2 TOKTYPE RESULT_FOLDER '
             '(e.g. java csharp bert ./java-csharp/bm25)')

lang1 = sys.argv[1]          # first language; its snippets form the BM25 index
lang2 = sys.argv[2]          # second language; its snippets are the queries
toktype = sys.argv[3]        # 'simple' -> space splitting, anything else -> BERT tokenizer
result_folder = sys.argv[4]  # output directory for retrieved pairs and acc.txt

tokenizer = AutoTokenizer.from_pretrained('microsoft/graphcodebert-base')


# In[3]:


def tokenize_bert(code):
    """Split *code* into subword tokens using the module-level ``tokenizer``.

    Tokens that are exactly 'Ċ' or 'Ġ' are discarded. These are presumably
    the byte-level BPE symbols for a bare newline / bare space (TODO confirm
    against the tokenizer vocabulary).
    """
    dropped = frozenset(('Ċ', 'Ġ'))
    pieces = tokenizer.tokenize(code)
    return list(filter(lambda piece: piece not in dropped, pieces))

def tokenize_simple(code):
    """Lower-case *code* and split it on single space characters.

    Empty pieces (produced by runs of spaces) are discarded. Only ' ' acts as
    a separator: tabs and newlines remain attached to neighbouring tokens.
    """
    pieces = code.lower().split(' ')
    # filter(None, ...) keeps only truthy, i.e. non-empty, strings.
    return list(filter(None, pieces))


# In[4]:


def get_TC_java_py_data(data_dir='../data/detok-tc-test-data'):
    """Load the paired Java / Python snippets from the TC test-data folder.

    Args:
        data_dir: Folder containing ``java.json`` and ``python.json``.
            Defaults to the original hard-coded relative location.

    Returns:
        ``(javacodes, pycodes)``, the objects decoded from the two JSON files.
        The script treats each one as a mapping from a shared key to a code
        string.
    """
    # Explicit UTF-8: the default text encoding of open() is platform-dependent.
    with open(os.path.join(data_dir, 'java.json'), 'r', encoding='utf-8') as f:
        javacodes = json.load(f)

    with open(os.path.join(data_dir, 'python.json'), 'r', encoding='utf-8') as f:
        pycodes = json.load(f)

    return javacodes, pycodes


def get_TC_java_cpp_data(data_dir='../data/detok-tc-test-data'):
    """Load the paired Java / C++ snippets from the TC test-data folder.

    Args:
        data_dir: Folder containing ``java.json`` and ``cpp.json``.
            Defaults to the original hard-coded relative location.

    Returns:
        ``(javacodes, cppcodes)``, the objects decoded from the two JSON files.
        The script treats each one as a mapping from a shared key to a code
        string.
    """
    # Explicit UTF-8: the default text encoding of open() is platform-dependent.
    with open(os.path.join(data_dir, 'java.json'), 'r', encoding='utf-8') as f:
        javacodes = json.load(f)

    with open(os.path.join(data_dir, 'cpp.json'), 'r', encoding='utf-8') as f:
        cppcodes = json.load(f)

    return javacodes, cppcodes


def get_TC_python_cpp_data(data_dir='../data/detok-tc-test-data'):
    """Load the paired Python / C++ snippets from the TC test-data folder.

    Args:
        data_dir: Folder containing ``python.json`` and ``cpp.json``.
            Defaults to the original hard-coded relative location.

    Returns:
        ``(pycodes, cppcodes)``, the objects decoded from the two JSON files.
        The script treats each one as a mapping from a shared key to a code
        string.
    """
    # Explicit UTF-8: the default text encoding of open() is platform-dependent.
    with open(os.path.join(data_dir, 'python.json'), 'r', encoding='utf-8') as f:
        pycodes = json.load(f)

    with open(os.path.join(data_dir, 'cpp.json'), 'r', encoding='utf-8') as f:
        cppcodes = json.load(f)

    return pycodes, cppcodes


def get_java_csharp_data(data_dir='../data/code-translation/java-C#/data'):
    """Load the line-aligned Java / C# training files.

    Line *i* of the ``.java`` file is paired with line *i* of the ``.cs`` file
    through the shared integer key *i*.

    Args:
        data_dir: Folder containing ``train.java-cs.txt.java`` and
            ``train.java-cs.txt.cs``. Defaults to the original hard-coded
            relative location.

    Returns:
        ``(javacodes, cscodes)``: two dicts mapping a 0-based line index to the
        raw line text. The trailing newline is kept, except possibly on the
        file's last line.
    """
    java_path = os.path.join(data_dir, 'train.java-cs.txt.java')
    cs_path = os.path.join(data_dir, 'train.java-cs.txt.cs')

    # Explicit UTF-8: the default text encoding of open() is platform-dependent.
    # Iterating a file yields the same lines as readlines(), without first
    # building an intermediate list.
    with open(java_path, 'r', encoding='utf-8') as f:
        javacodes = dict(enumerate(f))

    with open(cs_path, 'r', encoding='utf-8') as f:
        cscodes = dict(enumerate(f))

    return javacodes, cscodes


def get_data(data1, data2):
    """Load the paired corpora for the ordered language pair (data1, data2).

    Supported pairs: java-python, java-csharp, java-cpp, python-cpp. The order
    matters, e.g. ('python', 'java') is not accepted.

    Args:
        data1: Name of the first language.
        data2: Name of the second language.

    Returns:
        Tuple ``(codes1, codes2)`` as returned by the matching loader.

    Raises:
        ValueError: If ``(data1, data2)`` is not a supported pair.
    """
    if data1 == 'java' and data2 == 'python':
        return get_TC_java_py_data()

    if data1 == 'java' and data2 == 'csharp':
        return get_java_csharp_data()

    if data1 == 'java' and data2 == 'cpp':
        return get_TC_java_cpp_data()

    if data1 == 'python' and data2 == 'cpp':
        return get_TC_python_cpp_data()

    # Without this, an unknown pair returned None and only failed later with
    # an opaque "cannot unpack non-iterable NoneType" at the call site.
    raise ValueError(
        f'unsupported language pair {data1!r}-{data2!r}; expected one of '
        'java-python, java-csharp, java-cpp, python-cpp'
    )


# In[5]:


code1, code2 = get_data(lang1, lang2)

code1_keys = set(code1.keys())
code2_keys = set(code2.keys())

# Every first-language snippet needs a counterpart under the same key in the
# second language, because the two corpora are aligned by key. Raise
# explicitly instead of using `assert`, which is stripped under `python -O`.
missing_keys = code1_keys - code2_keys
if missing_keys:
    raise ValueError(
        f'{len(missing_keys)} {lang1} snippet(s) have no {lang2} counterpart, '
        f'e.g. keys {sorted(missing_keys)[:5]!r}'
    )


# In[6]:


# One fixed key order shared by both corpora, so that position i in each
# tokenized corpus refers to the same pair.
order = sorted(code1_keys)


# In[7]:


if toktype == 'simple':
    tokenize = tokenize_simple
else:
    # Every value other than 'simple' selects the subword tokenizer.
    tokenize = tokenize_bert

print(f'Using tokenizer {tokenize}')

# Both corpora are tokenized in the shared `order`, so index i lines up.
code1_tokenized_corpus = list(map(tokenize, (code1[k] for k in order)))
code2_tokenized_corpus = list(map(tokenize, (code2[k] for k in order)))


# In[8]:


# BM25 index over the first-language snippets.
bm25 = BM25Okapi(code1_tokenized_corpus)


# In[9]:


# Retrieval evaluation: each second-language snippet is used as a BM25 query,
# and the highest-scoring first-language snippet is taken as its match. A hit
# is counted when the match sits at the query's own position in `order`.
corr, total = 0, 0
code1_list, code2_list = [], []  # matched first-language / query snippets, in query order

with tqdm(enumerate(code2_tokenized_corpus), total=len(order)) as pbar:
    for i, query_tokens in pbar:
        query_key = order[i]
        code2_list.append(code2[query_key])

        scores = bm25.get_scores(query_tokens)
        # np.argmax returns the first maximum, so ties go to the lowest index.
        matching_idx = np.argmax(scores)
        code1_list.append(code1[order[matching_idx]])

        if i == matching_idx:
            corr += 1
        total += 1

        pbar.set_description(f'Accuracy: {corr / total * 100.0:0.3f}')

acc = corr / total  # fraction in [0, 1]
print(f'Accuracy: {acc * 100.0}')


# In[10]:


# Persist the results: line i of {lang1}.txt is the BM25 match retrieved for
# line i of {lang2}.txt. acc.txt holds the accuracy as a percentage.
# Create the folder instead of crashing here after the whole evaluation ran.
os.makedirs(result_folder, exist_ok=True)

for out_lang, snippets in ((lang1, code1_list), (lang2, code2_list)):
    with open(os.path.join(result_folder, f'{out_lang}.txt'), 'w', encoding='utf-8') as f:
        # writelines() adds no separators. Terminate every entry so that it
        # cannot merge with the next one, e.g. a final file line that was read
        # without a trailing newline.
        # NOTE(review): a snippet that itself contains newlines still spans
        # several output lines — confirm the intended output format.
        f.writelines(s if s.endswith('\n') else s + '\n' for s in snippets)

with open(os.path.join(result_folder, 'acc.txt'), 'w', encoding='utf-8') as f:
    f.write(f'Accuracy: {acc * 100.0}')


# In[ ]:




