import argparse
from multiprocessing import Pool
import os
import shutil


# Command-line interface. --data-dir is the root directory under which main()
# looks for one "<task>/processed" sub-directory per task.
# NOTE: the arguments are parsed at module level, i.e. as soon as this file is
# executed or imported, not only when main() is called.
parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', type=str)
args = parser.parse_args()
# The option is not marked required, so this is None when it is omitted.
data_dir = args.data_dir

def encode_subword(subwords, subwords_dict):
    """Tag each element of a multi-piece word group with its size and position.

    Every element is rewritten as ``"<element>@@<count>@@<index>"`` where
    ``count`` is ``len(subwords)`` and ``index`` is the element's 0-based
    position in the list.

    Args:
        subwords: List of strings forming one word group.
        subwords_dict: Not used by this function; kept so existing call
            sites keep working.

    Returns:
        ``subwords`` itself (the same list object) when it holds zero or one
        element, otherwise a new list with the tagged strings.
    """
    total = len(subwords)
    if total <= 1:
        # An empty group or an unsplit word carries no position information.
        return subwords
    return ["{}@@{}@@{}".format(piece, total, idx)
            for idx, piece in enumerate(subwords)]


def _expand_word_group(pieces, subwords_dict):
    """Return one output token per piece, each carrying the re-joined word."""
    word = "".join(pieces)
    group = [word] * len(pieces)
    if len(group) > 1:
        # Only words that were actually split get the size/position tags.
        group = encode_subword(group, subwords_dict)
    return group


def debpe_sentence(s):
    """Replace every BPE piece in a line with a copy of its full word.

    The line is split on whitespace. A token longer than two characters that
    ends in ``"@@"`` is treated as a continuation piece: with the marker
    stripped, it is joined to the tokens after it, up to and including the
    next token without the marker. Every token of such a group is replaced by
    the joined word (tagged through ``encode_subword`` when the group has more
    than one piece), so the output has exactly one token per input token.

    A line that ends on a continuation piece (no closing token) still has its
    pending pieces emitted as a final group instead of being dropped.

    Args:
        s: One line of BPE-segmented text.

    Returns:
        The rewritten tokens joined by single spaces, followed by a newline.
    """
    token_li = s.split()
    debpe_token_li = []
    pending = []
    subwords_dict = {}
    for t in token_li:
        if len(t) > 2 and t.endswith("@@"):
            # Continuation piece: strip the marker and wait for the word end.
            pending.append(t[:-2])
            continue
        pending.append(t)
        debpe_token_li += _expand_word_group(pending, subwords_dict)
        pending = []
    if pending:
        # The line ended in the middle of a word; flush what was collected so
        # the output stays aligned token-for-token with the input.
        debpe_token_li += _expand_word_group(pending, subwords_dict)
    # Internal invariant: each input token yields exactly one output token.
    assert len(token_li) == len(debpe_token_li)
    return ' '.join(debpe_token_li) + '\n'


def main(data_dir, tasks=None, processes=16):
    """Convert every ``*.bpe*`` file of each task into a ``.debpe`` file.

    For each task, the directory ``<data_dir>/<task>/processed`` is listed and
    every file whose name contains ``'.bpe'`` is read line by line, passed
    through ``debpe_sentence`` in a worker pool, and written next to the input
    under the same name with ``'.bpe'`` removed and ``'.debpe'`` appended.

    Args:
        data_dir: Root directory that contains one sub-directory per task.
        tasks: Iterable of task directory names; defaults to the eight task
            names listed below.
        processes: Number of worker processes in the pool.

    Raises:
        OSError: If a task directory cannot be listed or a file cannot be
            read or written.
    """
    if tasks is None:
        tasks = ['QQP', 'MNLI', 'QNLI', 'MRPC', 'RTE', 'STS-B', 'SST-2', 'CoLA']

    # One pool serves the whole run; the context manager also tears the
    # workers down if any step below raises.
    with Pool(processes=processes) as pool:
        for c in tasks:
            print("debpe corpus...")
            task_dir = data_dir + '/' + c + '/processed'
            print(task_dir)
            for fname in os.listdir(task_dir):
                if '.bpe' not in fname:
                    continue
                # Read everything first so the input handle is not held open
                # while the workers run.
                with open(task_dir + os.sep + fname, 'r') as f:
                    sents = f.readlines()
                # map() blocks until all lines are done and keeps their order.
                debpe_sents = pool.map(debpe_sentence, sents)
                out_path = task_dir + os.sep + fname.replace('.bpe', '') + '.debpe'
                with open(out_path, "w") as f1:
                    f1.writelines(debpe_sents)
                print("Processed {} lines for {}.".format(len(sents), c))


        
        
        

# Script entry point: run the conversion only when this file is executed
# directly, using the directory given on the command line.
if __name__ == "__main__":
    main(data_dir)