#!/usr/bin/env python
# coding: utf-8

# In[1]:


from gensim import corpora, models, similarities

import os, sys
import numpy as np
import json
from tqdm import tqdm
import sys

from transformers import AutoTokenizer


# In[2]:


# Command line: <lang1> <lang2> <toktype> <result_folder>
# Validate the arguments before loading the tokenizer so that a bad invocation
# fails immediately with a usage message instead of an IndexError raised after
# the model files have already been fetched/loaded.
if len(sys.argv) < 5:
    sys.exit(f'Usage: {sys.argv[0]} <lang1> <lang2> <toktype> <result_folder>')

lang1 = sys.argv[1]          # first language name, passed to get_data()
lang2 = sys.argv[2]          # second language name, passed to get_data()
toktype = sys.argv[3]        # 'simple' selects tokenize_simple; anything else tokenize_bert
result_folder = sys.argv[4]  # directory the result files are written to

# Tokenizer used by tokenize_bert().
tokenizer = AutoTokenizer.from_pretrained('microsoft/graphcodebert-base')


# In[3]:


def tokenize_bert(code):
    """Tokenize ``code`` with the module-level ``tokenizer``.

    Tokens that are exactly 'Ċ' or 'Ġ' are dropped from the result
    (presumably the tokenizer's stand-alone newline/space markers --
    confirm against the tokenizer vocabulary). All other tokens are
    returned unchanged, in order.
    """
    dropped = ['Ċ', 'Ġ']
    kept = []
    for piece in tokenizer.tokenize(code):
        if piece in dropped:
            continue
        kept.append(piece)
    return kept

def tokenize_simple(code):
    """Lower-case ``code`` and split it on single space characters.

    Empty pieces (produced by consecutive spaces) are discarded. Only the
    space character separates tokens; tabs and newlines stay inside tokens.
    """
    pieces = code.lower().split(' ')
    return [piece for piece in pieces if piece]


# In[4]:


def get_TC_java_py_data():
    """Load the Java and Python test sets from ``../data/detok-tc-test-data``.

    Returns:
        A ``(javacodes, pycodes)`` tuple holding the parsed contents of
        ``java.json`` and ``python.json``. The rest of the script treats each
        as a mapping from sample key to source-code string.

    Paths are relative to the current working directory.
    """
    # Explicit UTF-8: JSON text is UTF-8, and relying on the locale's default
    # encoding makes decoding of non-ASCII source code platform-dependent.
    with open('../data/detok-tc-test-data/java.json', 'r', encoding='utf-8') as f:
        javacodes = json.load(f)

    with open('../data/detok-tc-test-data/python.json', 'r', encoding='utf-8') as f:
        pycodes = json.load(f)

    return javacodes, pycodes


def get_TC_java_cpp_data():
    """Load the Java and C++ test sets from ``../data/detok-tc-test-data``.

    Returns:
        A ``(javacodes, cppcodes)`` tuple holding the parsed contents of
        ``java.json`` and ``cpp.json``. The rest of the script treats each
        as a mapping from sample key to source-code string.

    Paths are relative to the current working directory.
    """
    # Explicit UTF-8: JSON text is UTF-8, and relying on the locale's default
    # encoding makes decoding of non-ASCII source code platform-dependent.
    with open('../data/detok-tc-test-data/java.json', 'r', encoding='utf-8') as f:
        javacodes = json.load(f)

    with open('../data/detok-tc-test-data/cpp.json', 'r', encoding='utf-8') as f:
        cppcodes = json.load(f)

    return javacodes, cppcodes


def get_TC_python_cpp_data():
    """Load the Python and C++ test sets from ``../data/detok-tc-test-data``.

    Returns:
        A ``(pycodes, cppcodes)`` tuple holding the parsed contents of
        ``python.json`` and ``cpp.json``. The rest of the script treats each
        as a mapping from sample key to source-code string.

    Paths are relative to the current working directory.
    """
    # Explicit UTF-8: JSON text is UTF-8, and relying on the locale's default
    # encoding makes decoding of non-ASCII source code platform-dependent.
    with open('../data/detok-tc-test-data/python.json', 'r', encoding='utf-8') as f:
        pycodes = json.load(f)

    with open('../data/detok-tc-test-data/cpp.json', 'r', encoding='utf-8') as f:
        cppcodes = json.load(f)

    return pycodes, cppcodes


def get_java_csharp_data():
    """Load the line-aligned Java / C# files of the code-translation data.

    Each file is read line by line; line ``i`` (0-based) is stored under the
    integer key ``i``. Lines keep their trailing newline character.

    Returns:
        A ``(javacodes, cscodes)`` tuple of ``{line_index: line}`` dicts.

    Paths are relative to the current working directory.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    # Iterating the file object yields the same lines as readlines() without
    # building an intermediate list.
    with open('../data/code-translation/java-C#/data/train.java-cs.txt.java', 'r',
              encoding='utf-8') as f:
        javacodes = dict(enumerate(f))

    with open('../data/code-translation/java-C#/data/train.java-cs.txt.cs', 'r',
              encoding='utf-8') as f:
        cscodes = dict(enumerate(f))

    return javacodes, cscodes


def get_data(data1, data2):
    """Return the two parallel code collections for a language pair.

    Supported ``(data1, data2)`` pairs, in exactly this order:
    ``('java', 'python')``, ``('java', 'csharp')``, ``('java', 'cpp')`` and
    ``('python', 'cpp')``.

    Args:
        data1: Name of the first language.
        data2: Name of the second language.

    Returns:
        The ``(first, second)`` tuple produced by the matching loader.

    Raises:
        ValueError: If the pair is not one of the supported combinations.
            (Previously this fell through and returned ``None``, which only
            surfaced later as an unpacking ``TypeError`` in the caller.)
    """
    if data1 == 'java' and data2 == 'python':
        return get_TC_java_py_data()

    if data1 == 'java' and data2 == 'csharp':
        return get_java_csharp_data()

    if data1 == 'java' and data2 == 'cpp':
        return get_TC_java_cpp_data()

    if data1 == 'python' and data2 == 'cpp':
        return get_TC_python_cpp_data()

    raise ValueError(
        f'Unsupported language pair: ({data1!r}, {data2!r}). Supported pairs: '
        "('java', 'python'), ('java', 'csharp'), ('java', 'cpp'), ('python', 'cpp')."
    )


# In[5]:


code1, code2 = get_data(lang1, lang2)

code1_keys = set(code1.keys())
code2_keys = set(code2.keys())

# Every code1 sample must have a counterpart in code2, because code2 is indexed
# below with keys taken from code1. Raise explicitly rather than assert, so the
# check still runs under ``python -O``.
missing_keys = code1_keys - code2_keys
if missing_keys:
    raise ValueError(
        f'{len(missing_keys)} key(s) of the {lang1} data have no counterpart in '
        f'the {lang2} data, e.g. {sorted(missing_keys)[:5]}'
    )


# In[6]:


# Fixed (sorted) key order shared by both corpora: position i in either
# tokenized corpus refers to the sample with key order[i].
order = sorted(code1_keys)


# In[7]:


# Pick the tokenizer: 'simple' selects the whitespace tokenizer, any other
# value falls back to the GraphCodeBERT-based one.
if toktype == 'simple':
    tokenize = tokenize_simple
else:
    tokenize = tokenize_bert

print(f'Using tokenizer {tokenize}')

# Tokenize both corpora in the shared key order so that index i refers to the
# same sample in each list.
code1_tokenized_corpus = [tokenize(code1[sample_key]) for sample_key in order]
code2_tokenized_corpus = [tokenize(code2[sample_key]) for sample_key in order]


# In[8]:


# Build a single vocabulary over the token lists of both languages.
all_text = [*code1_tokenized_corpus, *code2_tokenized_corpus]
dictionary = corpora.Dictionary(all_text)


# In[9]:


# Vocabulary size (number of distinct tokens), printed for information.
feature_cnt = len(dictionary.token2id)
print(feature_cnt)


# In[10]:


# Bag-of-words vectors for the first language only; the LDA model is fitted on
# these documents with gensim's default hyper-parameters.
corpus = list(map(dictionary.doc2bow, code1_tokenized_corpus))
lda = models.LdaModel(corpus=corpus, id2word=dictionary)


# In[11]:


# Retrieval evaluation: for every code2 sample, find the most similar code1
# sample in LDA topic space and count a hit when it is the sample at the same
# position (i.e. the one with the same key).

# The similarity index over the LDA-projected code1 corpus does not depend on
# the query, so it is built once here instead of once per query (which made the
# loop quadratic in the corpus size and re-ran LDA inference over the whole
# corpus on every iteration). ``num_features`` is given explicitly because LDA
# output vectors are sparse: inferring the width from the data could yield
# fewer columns than the model has topics.
index = similarities.MatrixSimilarity(lda[corpus], num_features=lda.num_topics)

corr, total = 0, 0
# Retrieved code1 sample / query code2 sample, appended pairwise per query.
code1_list, code2_list = [], []


with tqdm(enumerate(code2_tokenized_corpus), total=len(order)) as pbar:
    for i, code in pbar:

        key = order[i]
        code2_list.append(code2[key])

        # Project the query into topic space and score it against all of code1.
        kw_vector = dictionary.doc2bow(code)
        sim = index[lda[kw_vector]]

        matching_idx = int(np.argmax(sim))
        # Sanity check kept from the original (fails e.g. on NaN scores), but
        # using the vectorised max instead of a Python-level scan.
        assert sim.max() == sim[matching_idx]

        key = order[matching_idx]
        code1_list.append(code1[key])

        if i == matching_idx:
            corr += 1
        total += 1

        # Running accuracy in percent, shown in the progress bar.
        acc = (corr / float(total)) * 100.0
        pbar.set_description(f'Accuracy: {acc:0.3f}')

# Final accuracy as a fraction; guard against an empty corpus.
acc = corr / float(total) if total else 0.0
print(f'Accuracy: {acc * 100.0}')


# In[ ]:


# Persist the retrieved code1 samples, the code2 queries and the accuracy.
# Create the output directory if needed so a long run does not fail at the
# very last step; an empty result_folder means "current directory".
if result_folder:
    os.makedirs(result_folder, exist_ok=True)

# Explicit UTF-8 so non-ASCII source code can be written regardless of locale.
# NOTE(review): writelines() adds no separators, so entries are only
# line-delimited if the stored code strings already end with a newline --
# confirm this holds for the JSON-based datasets.
with open(os.path.join(result_folder, f'{lang1}.txt'), 'w', encoding='utf-8') as f:
    f.writelines(code1_list)

with open(os.path.join(result_folder, f'{lang2}.txt'), 'w', encoding='utf-8') as f:
    f.writelines(code2_list)

with open(os.path.join(result_folder, 'acc.txt'), 'w', encoding='utf-8') as f:
    f.write(f'Accuracy: {acc * 100.0}')



