#!/usr/bin/env python
# coding: utf-8

# In[1]:


from rank_bm25 import BM25Okapi

import os, sys
import numpy as np
import json
from tqdm import tqdm

from transformers import AutoTokenizer


# In[2]:


# Command-line arguments: LANG1 LANG2 TOKTYPE RESULT_FOLDER.
# Validate them before the (slow) tokenizer download so a missing argument
# fails fast with a usage message instead of a bare IndexError.
if len(sys.argv) < 5:
    sys.exit(f'Usage: {sys.argv[0]} LANG1 LANG2 TOKTYPE RESULT_FOLDER')

lang1 = sys.argv[1]          # e.g. 'java', or 'CodeNet-<lang>' for CodeNet data
lang2 = sys.argv[2]          # e.g. 'csharp'
toktype = sys.argv[3]        # 'simple' selects tokenize_simple; any other value uses tokenize_bert
result_folder = sys.argv[4]  # e.g. './java-csharp/bm25'

# GraphCodeBERT sub-word tokenizer used by tokenize_bert().
tokenizer = AutoTokenizer.from_pretrained('microsoft/graphcodebert-base')


# In[3]:


def tokenize_bert(code):
    """Split *code* into sub-word tokens with the module-level ``tokenizer``.

    Tokens that are exactly 'Ċ' or 'Ġ' are dropped; tokens that merely
    contain those characters are kept. These are presumably the byte-level
    BPE markers for a bare newline / bare space (TODO confirm).
    """
    stripped = {'Ċ', 'Ġ'}
    pieces = tokenizer.tokenize(code)
    return [piece for piece in pieces if piece not in stripped]

def tokenize_simple(code):
    """Lower-case *code* and split it on any run of whitespace.

    ``str.split()`` with no separator treats newlines and tabs as separators
    as well as spaces, so a token ending one line is not glued to the token
    starting the next. It never yields empty strings, so no filtering is
    needed.
    """
    return code.lower().split()


# In[4]:


def get_TC_java_py_data(data_dir='../data/detok-tc-test-data'):
    """Load the Java and Python test sets stored as JSON in *data_dir*.

    Reads ``java.json`` and ``python.json``, decoding them as UTF-8 rather
    than with the platform's default encoding. The default *data_dir* is
    relative to the current working directory.

    Returns:
        ``(javacodes, pycodes)``: the parsed contents of the two files.
    """
    with open(os.path.join(data_dir, 'java.json'), 'r', encoding='utf-8') as f:
        javacodes = json.load(f)

    with open(os.path.join(data_dir, 'python.json'), 'r', encoding='utf-8') as f:
        pycodes = json.load(f)

    return javacodes, pycodes


def get_TC_java_cpp_data(data_dir='../data/detok-tc-test-data'):
    """Load the Java and C++ test sets stored as JSON in *data_dir*.

    Reads ``java.json`` and ``cpp.json``, decoding them as UTF-8 rather
    than with the platform's default encoding. The default *data_dir* is
    relative to the current working directory.

    Returns:
        ``(javacodes, cppcodes)``: the parsed contents of the two files.
    """
    with open(os.path.join(data_dir, 'java.json'), 'r', encoding='utf-8') as f:
        javacodes = json.load(f)

    with open(os.path.join(data_dir, 'cpp.json'), 'r', encoding='utf-8') as f:
        cppcodes = json.load(f)

    return javacodes, cppcodes


def get_TC_python_cpp_data(data_dir='../data/detok-tc-test-data'):
    """Load the Python and C++ test sets stored as JSON in *data_dir*.

    Reads ``python.json`` and ``cpp.json``, decoding them as UTF-8 rather
    than with the platform's default encoding. The default *data_dir* is
    relative to the current working directory.

    Returns:
        ``(pycodes, cppcodes)``: the parsed contents of the two files.
    """
    with open(os.path.join(data_dir, 'python.json'), 'r', encoding='utf-8') as f:
        pycodes = json.load(f)

    with open(os.path.join(data_dir, 'cpp.json'), 'r', encoding='utf-8') as f:
        cppcodes = json.load(f)

    return pycodes, cppcodes


def get_java_csharp_data(data_dir='../data/code-translation/java-C#/data'):
    """Load the Java / C# training files from *data_dir*, one sample per line.

    Line ``i`` of ``train.java-cs.txt.java`` is stored under key ``i``, and
    likewise for ``train.java-cs.txt.cs``. Lines keep their trailing newline.
    Both files are decoded as UTF-8. The default *data_dir* is relative to
    the current working directory.

    Returns:
        ``(javacodes, cscodes)``: two dicts mapping line index to line text.
    """
    with open(os.path.join(data_dir, 'train.java-cs.txt.java'), 'r', encoding='utf-8') as f:
        javacodes = dict(enumerate(f))

    with open(os.path.join(data_dir, 'train.java-cs.txt.cs'), 'r', encoding='utf-8') as f:
        cscodes = dict(enumerate(f))

    return javacodes, cscodes


def _read_codenet_submissions(lang_path):
    """Read every submission under *lang_path* into ``{'<problem>-<file>': source}``.

    Only entries whose name starts with 'p' are treated as problem
    directories. Inside them, only files whose name starts with 's' are
    read; everything else is skipped.
    """
    codedict = {}
    for problem in os.listdir(lang_path):
        if not problem.startswith('p'):
            continue

        problem_dir = os.path.join(lang_path, problem)
        for submission_fname in os.listdir(problem_dir):
            if not submission_fname.startswith('s'):
                continue

            with open(os.path.join(problem_dir, submission_fname), 'r', encoding='utf-8') as f:
                codedict[f'{problem}-{submission_fname}'] = f.read()

    return codedict


def get_codenet_data(lang1, lang2,
                     basepath='/dccstor/mayankag1/wmd-code/codenet-genai/matchexp_data'):
    """Load CodeNet submissions for two languages.

    Each language is a sub-directory of *basepath* laid out as
    ``<lang>/<problem>/<submission>``. The problem id is the part of each
    returned key before the first '-', which is what
    ``get_codenet_accuracy`` compares.

    Returns:
        ``(lang1_codedict, lang2_codedict)``: dicts mapping
        ``'<problem>-<submission file>'`` to the file's source text.
    """
    lang1_codedict = _read_codenet_submissions(os.path.join(basepath, lang1))
    lang2_codedict = _read_codenet_submissions(os.path.join(basepath, lang2))
    return lang1_codedict, lang2_codedict


def get_data(data1, data2):
    """Load the two corpora to be matched for the dataset pair (data1, data2).

    Supported pairs:
    ('java', 'python'), ('java', 'csharp'), ('java', 'cpp'), ('python', 'cpp'),
    and any pair where both names start with 'CodeNet'. For CodeNet, the
    second '-'-separated field of each name (e.g. 'CodeNet-<lang>') selects
    the language directory.

    Returns:
        ``(codes1, codes2)`` as produced by the matching loader.

    Raises:
        ValueError: if the pair is not one of the supported combinations.
    """
    if data1 == 'java' and data2 == 'python':
        return get_TC_java_py_data()

    elif data1 == 'java' and data2 == 'csharp':
        return get_java_csharp_data()

    elif data1 == 'java' and data2 == 'cpp':
        return get_TC_java_cpp_data()

    elif data1 == 'python' and data2 == 'cpp':
        return get_TC_python_cpp_data()

    elif data1.startswith('CodeNet') and data2.startswith('CodeNet'):
        lang1 = data1.split('-')[1]
        lang2 = data2.split('-')[1]
        return get_codenet_data(lang1, lang2)

    # Fail loudly here instead of returning None, which would only surface
    # later as a confusing "cannot unpack NoneType" error at the call site.
    raise ValueError(f'Unsupported dataset pair: {data1!r}, {data2!r}')


def get_codenet_accuracy(key1, key2):
    """Tell whether two CodeNet keys refer to the same problem.

    The problem id is taken as everything before the first '-' (the whole
    key if it has none). Returns a bool.
    """
    head1, _, _ = key1.partition('-')
    head2, _, _ = key2.partition('-')
    return head1 == head2


# In[5]:


# Load both corpora for the requested dataset pair.
code1, code2 = get_data(lang1, lang2)

code1_keys, code2_keys = set(code1.keys()), set(code2.keys())

# Disabled sanity check: the position-based accuracy used for non-CodeNet
# pairs only makes sense when both corpora share the same keys.
# assert len(code1_keys.difference(code2_keys)) == 0


# In[6]:


# Deterministic orderings; position i in each list names that corpus's i-th document.
order_code1, order_code2 = sorted(code1_keys), sorted(code2_keys)


# In[7]:


# 'simple' selects tokenize_simple; any other toktype falls back to tokenize_bert.
if toktype == 'simple':
    tokenize = tokenize_simple
else:
    tokenize = tokenize_bert

print(f'Using tokenizer {tokenize}')

# Token lists, aligned with order_code1 / order_code2 respectively.
code1_tokenized_corpus = [tokenize(code1[k1]) for k1 in order_code1]
code2_tokenized_corpus = [tokenize(code2[k2]) for k2 in order_code2]


# In[8]:


# BM25 index over the lang1 corpus; lang2 snippets are the queries.
bm25 = BM25Okapi(code1_tokenized_corpus)


# In[9]:


# Retrieval loop: each lang2 snippet is a BM25 query against the lang1 index,
# and the highest-scoring lang1 snippet is taken as its match.
corr, total = 0, 0
code1_list, code2_list = [], []  # matched lang1 snippets / their lang2 queries, in query order

# For CodeNet pairs, a hit means both keys share a problem id. For the other
# datasets, a hit means the match sits at the query's own position, which
# assumes both sorted key lists are aligned.
is_codenet = lang1.startswith('CodeNet') and lang2.startswith('CodeNet')

with tqdm(enumerate(code2_tokenized_corpus), total=len(order_code2)) as pbar:
    for i, code in pbar:
        key_code2 = order_code2[i]
        # Keep the query text so {lang2}.txt lines up entry-for-entry with {lang1}.txt.
        code2_list.append(code2[key_code2])

        scores = bm25.get_scores(code)
        matching_idx = np.argmax(scores)  # ties resolve to the first (lowest) index

        key_code1 = order_code1[matching_idx]
        code1_list.append(code1[key_code1])

        if is_codenet:
            corr += get_codenet_accuracy(key_code1, key_code2)
        elif i == matching_idx:
            corr += 1
        total += 1

        pbar.set_description(f'Accuracy: {(corr / float(total)) * 100.0:0.3f}')

# An empty query corpus would otherwise raise ZeroDivisionError here.
acc = corr / float(total) if total else 0.0
print(f'Accuracy: {acc * 100.0}')


# In[10]:


# Persist the matched lang1 snippets, the lang2 queries, and the accuracy.
# Create the output folder first; open(..., 'w') does not create directories.
os.makedirs(result_folder, exist_ok=True)

with open(os.path.join(result_folder, f'{lang1}.txt'), 'w', encoding='utf-8') as f:
    # NOTE(review): entries are written back-to-back with no separator. Only
    # samples that already end in a newline come out one per line — confirm
    # this is the intended format for the JSON-backed datasets.
    f.writelines(code1_list)

with open(os.path.join(result_folder, f'{lang2}.txt'), 'w', encoding='utf-8') as f:
    f.writelines(code2_list)

with open(os.path.join(result_folder, 'acc.txt'), 'w', encoding='utf-8') as f:
    f.write(f'Accuracy: {acc * 100.0}')


# In[ ]:




