#!/usr/bin/env python
# coding: utf-8

# ## Get a list of problem sets from matched dataset

# In[8]:


import os
import json
import pandas as pd
import numpy as np
import pickle

from tqdm.notebook import tqdm
from transformers import AutoTokenizer


# In[2]:


# Directory holding one sub-directory per language pair; each is read below
# for its train.jsonl / val.jsonl files.
datapath = "./data/codenet-jsonl-processed"
# Output directory; the codepairs.pkl checkpoint is read from and written here.
outdir = "./random-data/codenet-jsonl-processed"
# Root of the Project CodeNet checkout; 'metadata/' and 'data/' are read beneath it.
codenetdir = "../Project_CodeNet"

# Create the output directory up front so the checkpoint write cannot fail on a missing path.
os.makedirs(outdir, exist_ok=True)


# In[3]:


# Every entry of datapath whose name does not start with a dot is taken as a
# language-pair directory name.
langpairs = list(
    filter(lambda entry: not entry.startswith('.'), os.listdir(datapath))
)
len(langpairs)


# In[4]:


# Pretrained GraphCodeBERT tokenizer; tokenize() below uses it to measure code length.
tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")

def tokenize(code):
    """Return the number of tokens the module-level tokenizer yields for *code*.

    Tokens that are exactly 'Ċ' or 'Ġ' are not counted (presumably the
    tokenizer's stand-alone newline / space markers -- confirm against the
    tokenizer vocabulary). Despite its name, this returns a count, not the
    token list.
    """
    ignored = ('Ċ', 'Ġ')
    return sum(1 for piece in tokenizer.tokenize(code) if piece not in ignored)


# In[5]:


# Maps each language pair to:
#   'probid' -- the set of problem ids found in its train and val splits
#   'n'      -- the total number of records across both splits
probdict = {}

for langs in tqdm(langpairs):
    probids = set()
    cnt = 0

    # Both splits are handled identically, so read them with one loop.
    for split in ('train.jsonl', 'val.jsonl'):
        fpath = os.path.join(datapath, langs, split)

        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of
                # file) instead of crashing in json.loads.
                if not line.strip():
                    continue
                record = json.loads(line)
                probids.add(record['prob'])
                cnt += 1

    probdict[langs] = {'probid': probids, 'n': cnt}


# ## Read codenet

# In[6]:


def check_length(row, thres=512):
    """Return True if the submission described by *row* has at most *thres* tokens.

    Args:
        row: Mapping (e.g. a metadata DataFrame row) providing the keys
            'problem_id', 'language', 'submission_id' and 'filename_ext'.
        thres: Maximum token count (as measured by tokenize()) that is still
            accepted. Defaults to 512.

    Returns:
        bool: whether the token count of the submission's source file, with
        leading/trailing whitespace stripped, is <= *thres*.

    Raises:
        OSError: if the source file cannot be opened.
    """
    path = os.path.join(
        codenetdir,
        'data',
        row['problem_id'],
        row['language'],
        f"{row['submission_id']}.{row['filename_ext']}",
    )

    # Submissions are arbitrary user-written files and may contain bytes that
    # are not valid UTF-8. Decode explicitly and substitute undecodable bytes
    # so one odd file cannot abort the whole run with UnicodeDecodeError.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        code = f.read().strip()

    return tokenize(code) <= thres

def read_codenet_metadata(problem, lang1, lang2):
    """Randomly sample equally many accepted submissions of *problem* per language.

    Reads '<codenetdir>/metadata/<problem>.csv', keeps rows whose 'status' is
    'Accepted' and whose 'language' is *lang1* or *lang2*, drops submissions
    rejected by check_length(), shuffles each language's submission ids and
    truncates both lists to the length of the shorter one.

    Args:
        problem: Problem id; names the metadata CSV file.
        lang1: Language value collected into the first returned list.
        lang2: Language value collected into the second returned list.

    Returns:
        tuple[list, list]: (lang1 submission ids, lang2 submission ids), both
        of the same length and in random order. Elements are plain Python
        objects rather than numpy scalars.
    """
    fpath = os.path.join(codenetdir, 'metadata', f'{problem}.csv')
    metadata = pd.read_csv(fpath)

    wanted = (
        metadata['language'].isin([lang1, lang2])
        & (metadata['status'] == 'Accepted')
    )
    metadata = metadata[wanted]

    lang1_subs, lang2_subs = [], []
    for _, row in metadata.iterrows():
        if not check_length(row):
            continue

        # After the filter above every remaining row is lang1 or lang2.
        if row['language'] == lang1:
            lang1_subs.append(row['submission_id'])
        else:
            lang2_subs.append(row['submission_id'])

    n = min(len(lang1_subs), len(lang2_subs))

    # Shuffle lang1 before lang2 to keep the random-number draw order stable.
    # .tolist() converts the numpy scalars back to native Python values, so
    # anything pickled from this result does not depend on numpy.
    lang1_subs = np.random.permutation(lang1_subs)[:n].tolist()
    lang2_subs = np.random.permutation(lang2_subs)[:n].tolist()

    return lang1_subs, lang2_subs


# In[ ]:


def _round_robin_pairs(probs, subs1, subs2, n):
    """Pick up to *n* (problem, sub1, sub2) triples, cycling over the problems.

    Pass k takes the k-th submission pair of every problem that has one, so
    pairs are spread across problems before any problem contributes a second
    pair. Stops as soon as *n* triples are collected, or when a full pass
    adds nothing (all problems exhausted). Without that second condition the
    loop would never terminate when fewer than *n* pairs exist.

    Args:
        probs: Sequence of problem ids.
        subs1: Per-problem lists of first-language submission ids, aligned
            with *probs* by index.
        subs2: Per-problem lists of second-language submission ids, aligned
            with *probs* by index.
        n: Maximum number of triples to return.

    Returns:
        list[tuple]: at most *n* triples (problem, subs1 id, subs2 id).
    """
    pairs = []
    depth = 0
    while len(pairs) < n:
        added = False
        for i, prob in enumerate(probs):
            if len(pairs) >= n:
                break
            if len(subs1[i]) <= depth or len(subs2[i]) <= depth:
                continue
            pairs.append((prob, subs1[i][depth], subs2[i][depth]))
            added = True

        if not added:
            # No problem has a pair at this depth, so none has one deeper.
            break
        depth += 1

    return pairs


checkpoint_path = os.path.join(outdir, 'codepairs.pkl')

# Resume from a previous run if a checkpoint exists. Any failure to load is
# reported and treated as "no checkpoint" (best effort).
# NOTE: pickle.load can execute arbitrary code; this is only acceptable
# because the file is a checkpoint written by this same script.
try:
    with open(checkpoint_path, 'rb') as f:
        codepairs = pickle.load(f)
except Exception as err:
    print(err)
    codepairs = {}

for langs, vals in tqdm(probdict.items()):

    if langs in codepairs:
        print(f'{langs} in codepairs. skipping...')
        continue

    lang1, lang2 = langs.split('-')
    # Freeze the set into a list: it is iterated twice and the per-problem
    # submission lists below are matched to it by index.
    probs = list(vals['probid'])
    n = vals['n']

    allsubs_lang1, allsubs_lang2 = [], []
    for prob in tqdm(probs, desc="Problems", leave=False):
        lang1_subs, lang2_subs = read_codenet_metadata(prob, lang1, lang2)
        allsubs_lang1.append(lang1_subs)
        allsubs_lang2.append(lang2_subs)

    lang_codepairs = _round_robin_pairs(probs, allsubs_lang1, allsubs_lang2, n)
    if len(lang_codepairs) < n:
        print(f'{langs}: only {len(lang_codepairs)} of {n} requested pairs available.')

    codepairs[langs] = lang_codepairs

    # Checkpoint after every language pair so an interrupted run can resume.
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(codepairs, f)


# In[ ]:




