import argparse
from multiprocessing import Pool


parser = argparse.ArgumentParser(
    description="De-BPE corpus.{train,valid}.tok.bpe and mark subwords of "
                "words in a band of corpus.train.vocab.")
# All three options are mandatory: the code below cannot run without them
# (previously a missing option surfaced later as a TypeError).
parser.add_argument('--data-dir', type=str, required=True,
                    help="directory containing corpus.train.vocab and the "
                         "corpus.<split>.tok.bpe files")
parser.add_argument('--left-percent', type=float, required=True,
                    help="start of the kept band as a fraction of the vocab "
                         "size (first kept index = int(size * value))")
parser.add_argument('--right-percent', type=float, required=True,
                    help="end of the kept band as a fraction of the vocab "
                         "size (exclusive index = int(size * value))")
args = parser.parse_args()
data_dir = args.data_dir
left_percent = args.left_percent
right_percent = args.right_percent

# filter words by frequency: keep the vocab lines whose position falls in
# [left_idx, right_idx). NOTE(review): this assumes corpus.train.vocab is
# ordered by frequency — confirm against the tool that produced it.
with open(data_dir + "/corpus.train.vocab", "r", encoding="utf-8") as f:
    lines = f.readlines()
num_words = len(lines)
left_idx = int(num_words * left_percent)
right_idx = int(num_words * right_percent)
filtered_lines = lines[left_idx:right_idx]
# word -> second space-separated field of its vocab line (the frequency).
# The value is kept as the raw string, so for "word freq\n" lines it still
# carries the trailing newline.
filtered_words_dict = {}
for line in filtered_lines:
    fields = line.split(' ')
    filtered_words_dict[fields[0]] = fields[1]
print("# of words is {}. Extract words between ({}, {}).".format(
    num_words, left_idx, right_idx))
    

def encode_subword(subwords, subwords_dict, filtered_words=None):
    """Tag the entries of one multi-piece word and record it if filtered.

    Args:
        subwords: one entry per BPE piece of a single word (as called from
            ``debpe_sentence`` every entry is already the merged word).
        subwords_dict: mapping word -> tagged pieces. It is updated in place
            (an existing entry for the word is kept) and also returned.
        filtered_words: words whose pieces should be marked; any container
            supporting ``in``. Defaults to the module-level
            ``filtered_words_dict``.

    Returns:
        ``(tagged, subwords_dict, mask)``.
        - ``tagged`` holds ``"<entry>@@<count>@@<index>"`` for each entry.
        - ``mask`` is ``'1'`` per entry when ``subwords[0]`` is in
          *filtered_words*, else ``'0'``.
        - With fewer than two entries nothing is tagged: the input list is
          returned unchanged with an all-``'0'`` mask.
    """
    n = len(subwords)
    if n <= 1:
        # Keep the same 3-tuple shape as the main path so callers can always
        # unpack the result (this used to return only the list).
        return subwords, subwords_dict, ['0'] * n
    if filtered_words is None:
        filtered_words = filtered_words_dict
    tagged = ["{}@@{}@@{}".format(sb, n, i) for i, sb in enumerate(subwords)]
    word = subwords[0]
    if word not in filtered_words:
        return tagged, subwords_dict, ['0'] * n
    # Store a copy so the list handed back to the caller never aliases the
    # dict entry; the first occurrence of a word wins.
    subwords_dict.setdefault(word, list(tagged))
    return tagged, subwords_dict, ['1'] * n


def _merge_pieces(pieces, subwords_dict):
    """Merge one word's BPE pieces; return ``(tokens, subwords_dict, mask)``."""
    word = "".join(pieces)
    merged = [word] * len(pieces)
    if len(merged) > 1:
        return encode_subword(merged, subwords_dict)
    # A single-piece word is emitted as-is and never marked.
    return merged, subwords_dict, ['0']


def debpe_sentence(s):
    """Undo BPE on one sentence while keeping one output token per input token.

    A whitespace token longer than two characters that ends in ``"@@"``
    continues into the next token. Every piece of a word is replaced by the
    merged word. Multi-piece words are then tagged by ``encode_subword`` as
    ``"<word>@@<num_pieces>@@<piece_index>"``.

    Args:
        s: one line of BPE-segmented text.

    Returns:
        ``(debpe_s, subwords_dict, subwords_mask_s)``.
        - ``debpe_s`` is the rewritten line with a trailing newline.
        - ``subwords_dict`` is the dict filled by ``encode_subword``.
        - ``subwords_mask_s`` is a space-separated '0'/'1' mask, one entry
          per token, with a trailing newline.
    """
    token_li = s.split()
    debpe_token_li = []
    subword_token_li = []
    subwords_dict = {}
    pending = []
    last = len(token_li) - 1
    for idx, t in enumerate(token_li):
        is_continuation = len(t) > 2 and t[-2:] == "@@"
        pending.append(t[:-2] if is_continuation else t)
        # Emit at each word boundary, and also at the end of the line so a
        # dangling continuation piece (e.g. a truncated line) is not dropped,
        # which would misalign the output with the input tokens.
        if not is_continuation or idx == last:
            words, subwords_dict, mask = _merge_pieces(pending, subwords_dict)
            debpe_token_li.extend(words)
            subword_token_li.extend(mask)
            pending = []
    # Internal invariants: one output token and one mask entry per input token.
    assert len(token_li) == len(debpe_token_li)
    assert len(token_li) == len(subword_token_li)
    debpe_s = ' '.join(debpe_token_li) + '\n'
    subwords_mask_s = ' '.join(subword_token_li) + '\n'
    return debpe_s, subwords_dict, subwords_mask_s


def _build_filtered_dict_lines(subwords_dicts):
    """Return ``dict.debpe.filtered`` lines with split words replaced by pieces."""
    subwords_dict = {}
    for d in subwords_dicts:
        subwords_dict.update(d)
    # Work on a copy: encode_subword consults the module-level
    # filtered_words_dict (also from forked pool workers), so mutating it
    # would change the masks computed for any later split.
    entries = dict(filtered_words_dict)
    for word, pieces in subwords_dict.items():
        # KeyError here means a word outside the filtered band was tagged.
        freq = entries.pop(word)
        for piece in pieces:
            entries[piece] = freq
    # int() ignores surrounding whitespace, so the raw "freq\n" strings parse
    # correctly even when the final vocab line has no trailing newline.
    ranked = sorted(entries.items(), key=lambda kv: int(kv[1]), reverse=True)
    return ["{} {}\n".format(word, freq.strip()) for word, freq in ranked]


def main(data_dir, splits=('train', 'valid'), processes=16):
    """De-BPE each corpus split and write the results into *data_dir*.

    For each split ``c`` the function:
    - reads ``corpus.{c}.tok.bpe``;
    - runs ``debpe_sentence`` on every line in a process pool;
    - writes ``corpus.{c}.tok.debpe`` (merged tokens) and
      ``corpus.{c}.tok.sbmask`` (per-token 0/1 mask).

    For the ``train`` split it also writes ``dict.debpe.filtered``. This is
    the filtered vocabulary with every tagged word replaced by its tagged
    pieces, each inheriting the word's frequency, sorted by descending
    frequency.

    Args:
        data_dir: directory holding the corpus files.
        splits: names of the splits to process.
        processes: number of worker processes in the pool.
    """
    for c in splits:
        # debpe corpus...
        print("debpe corpus...")
        with open(data_dir + "/corpus.{}.tok.bpe".format(c), "r",
                  encoding="utf-8") as f:
            sents = f.readlines()
        # The context manager tears the workers down even if a task raises;
        # map() only returns once every task has finished.
        with Pool(processes=processes) as pool:
            res_li = pool.map(debpe_sentence, sents)
        debpe_sents = [r[0] for r in res_li]
        tmp_subwords_dicts = [r[1] for r in res_li]
        subword_mask_sents = [r[2] for r in res_li]

        assert len(sents) == len(debpe_sents)
        assert len(sents) == len(subword_mask_sents)

        if c == 'train':
            # save new dict for tnf dataset (only the train split's dict is
            # written, so it is only built here)...
            dict_lines = _build_filtered_dict_lines(tmp_subwords_dicts)
            with open(data_dir + "/dict.debpe.filtered", "w",
                      encoding="utf-8") as f:
                f.writelines(dict_lines)
        with open(data_dir + "/corpus.{}.tok.debpe".format(c), "w",
                  encoding="utf-8") as f:
            f.writelines(debpe_sents)
        with open(data_dir + "/corpus.{}.tok.sbmask".format(c), "w",
                  encoding="utf-8") as f:
            f.writelines(subword_mask_sents)
        print("Processed {} lines for {}.".format(len(sents), c))
        
        

if __name__ == "__main__":
    # Script entry point: process the corpus in the directory given by
    # --data-dir (parsed at module level above).
    main(data_dir=data_dir)