#!/usr/bin/env python
# coding: utf-8

# In[1]:


from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import json
from pyemd import emd
from tqdm import tqdm
from multiprocessing import Pool

from sklearn.metrics import euclidean_distances
from sklearn.preprocessing import normalize

from wmd import WordMoversDistance


# In[2]:


# Both the tokenizer and the encoder come from the same pretrained checkpoint,
# so the checkpoint name is spelled out once.
CODEBERT_CHECKPOINT = "microsoft/codebert-base"

tokenizer = AutoTokenizer.from_pretrained(CODEBERT_CHECKPOINT)
model = AutoModel.from_pretrained(CODEBERT_CHECKPOINT)

# Only the static word-embedding weights of the model are used below; they are
# detached from autograd and exposed as a numpy array.
word_embedding_weights = model.embeddings.word_embeddings.weight
embeddings = word_embedding_weights.detach().numpy()


# In[3]:


# Distance model built on top of the embedding matrix.
# NOTE(review): n_jobs=-1 presumably means "use all cores" -- confirm against
# the WordMoversDistance implementation in the `wmd` module.
wmd = WordMoversDistance(embeddings, n_jobs=-1, verbose=0)


# In[4]:


def tokenize(code):
    """Turn a source-code string into tokenizer vocabulary ids.

    The string is split with the module-level ``tokenizer``; tokens that are
    exactly 'Ċ' or 'Ġ' are discarded (presumably the tokenizer's stand-alone
    newline / space markers -- confirm), and the remaining tokens are mapped
    to ids with ``tokenizer.convert_tokens_to_ids``.

    Args:
        code: Source code text to tokenize.

    Returns:
        The value returned by ``tokenizer.convert_tokens_to_ids`` for the
        retained tokens.
    """
    skipped = ('Ċ', 'Ġ')
    pieces = tokenizer.tokenize(code)
    kept = [piece for piece in pieces if piece not in skipped]
    return tokenizer.convert_tokens_to_ids(kept)


# In[5]:


# Load the Java and Python test sets. Each file is parsed as one JSON object;
# later cells treat it as a mapping whose keys are snippet ids and whose
# values are source-code strings.
# The encoding is pinned to UTF-8 (the standard JSON encoding) so that
# non-ASCII characters in the code are decoded the same way on every platform
# instead of depending on the locale's default encoding.
with open('../data/detok-tc-test-data/java.json', 'r', encoding='utf-8') as f:
    data_java = json.load(f)

with open('../data/detok-tc-test-data/python.json', 'r', encoding='utf-8') as f:
    data_py = json.load(f)


# In[6]:


# Split each mapping into two parallel tuples: its keys (snippet ids) and its
# values (source code), in the mapping's iteration order.
ids_py, code_py = zip(*data_py.items())
ids_java, code_java = zip(*data_java.items())


# In[7]:


# Token-id sequence for every snippet, in the same order as the id tuples.
tokens_py = list(map(tokenize, code_py))
tokens_java = list(map(tokenize, code_java))


# In[8]:


# Notebook sanity check of the corpus sizes; the value is only displayed when
# run as a notebook cell and has no effect when run as a script.
len(tokens_py), len(tokens_java)


# In[9]:


# The Java snippets are the reference set that Python snippets are later
# compared against.
wmd.fit(ids_java, tokens_java)


# In[10]:


# result[python_id][java_id] -> distance between the Python snippet and the
# Java snippet, for every Python snippet against every fitted Java snippet.
result = {}

# Iterate ids and token sequences in lockstep; `total` is passed explicitly
# because zip() has no length for tqdm to display. The loop variable is named
# `py_id` so the builtin `id` is not shadowed at module level.
for py_id, py_tokens in tqdm(zip(ids_py, tokens_py), total=len(ids_py)):
    dists = wmd.predict(py_tokens)
    # Pair every source id known to the fitted model with its distance.
    # NOTE(review): assumes `dists` is ordered like `wmd.source_ids` -- confirm
    # against the WordMoversDistance implementation.
    result[py_id] = dict(zip(wmd.source_ids, dists))


# In[ ]:


# Persist the full Python -> Java distance table as JSON.
results_path = '../results/TC-test-set/python-java-codebert-parallel.json'
with open(results_path, 'w') as out_file:
    json.dump(result, out_file)


# In[ ]:


# Top-1 retrieval accuracy: a Python snippet counts as correct when the Java
# snippet with the smallest distance has the same id as the Python snippet.
# NOTE(review): this presumes the two datasets share ids for matching
# snippets -- verify against the data files.
corr, tot = 0, 0

for k, v in result.items():
    tot += 1
    if not v:
        # No candidates for this query: count it as a miss instead of
        # crashing on an empty sequence.
        continue
    # min() finds the nearest candidate in one pass without sorting them all;
    # on ties it returns the first key in iteration order, which is the same
    # element a stable sort would place first.
    nearest_java_id = min(v, key=v.get)
    if k == nearest_java_id:
        corr += 1

# Guard the empty-result case so the summary line never divides by zero.
accuracy = corr / float(tot) if tot else 0.0
print(accuracy, corr, tot)


# In[ ]:





# In[ ]:





# In[ ]:




