from tqdm import tqdm
import joblib
import numpy as np 

from symbolic_utils import * 


def num_tokens_check(text):
    """Count whitespace-separated tokens on each side of an ``input : target`` line.

    The line is split at its FIRST ``':'``, so extra colons inside the target
    side are counted as ordinary tokens instead of breaking the split.

    Args:
        text: A line of the form ``'<input tokens> : <target tokens>'``.

    Returns:
        Tuple ``(total, n_input, n_target)`` with ``total == n_input + n_target``.

    Raises:
        ValueError: If ``text`` contains no ``':'`` separator.
    """
    input_text, sep, target_text = text.partition(':')
    if not sep:
        raise ValueError(f'missing ":" separator in line: {text!r}')
    ni, nt = len(input_text.split()), len(target_text.split())
    return ni + nt, ni, nt

def postprocessing(data_path, _postprocessing, extention, encoding='prefix', n_jobs=-1):
    """Convert raw ``input : target`` split files into post-processed files.

    For each split in ``('test', 'train')`` the file
    ``{data_path}.{split}.{extention}`` is read (the split is skipped if the
    file does not exist). Every non-blank line is split on ``':'`` into an
    input and a target text. ``_postprocessing(input, target, encoding=encoding)``
    is applied to each pair in parallel with joblib. The results are written,
    one per line, to ``{data_path}.{split}``, and the maximum token counts of
    the written lines are printed.

    Args:
        data_path: Path prefix shared by all split files.
        _postprocessing: Callable ``(input_text, target_text, encoding=...) -> str``
            producing one output line.
        extention: Extension of the raw input files (e.g. ``'raw'``).
        encoding: Forwarded unchanged to ``_postprocessing``.
        n_jobs: Number of joblib workers (``-1`` means all available cores).

    Raises:
        ValueError: If a non-blank input line has no ``':'`` separator.
    """
    import os

    for ext in ['test', 'train']:
        filename = data_path + f'.{ext}.{extention}'
        filename_out = data_path + f'.{ext}'

        if not os.path.exists(filename):
            continue

        with open(filename, "r") as f:
            data = f.read().splitlines()

        input_texts, target_texts = [], []
        for lineno, line in enumerate(data, start=1):
            if not line.strip():
                # Blank lines carry no sample; skip them instead of crashing.
                continue
            fields = line.split(":")
            if len(fields) < 2:
                raise ValueError(
                    f'{filename}:{lineno}: expected "input : target", got {line!r}'
                )
            # Only the first two ':'-separated fields are used; anything after
            # a second ':' is ignored.
            input_texts.append(fields[0].strip())
            target_texts.append(fields[1].strip())

        texts = joblib.Parallel(n_jobs=n_jobs, backend="loky", verbose=True)(
            joblib.delayed(_postprocessing)(intext, tartext, encoding=encoding)
            for intext, tartext in zip(input_texts, target_texts)
        )
        with open(filename_out, "w", buffering=4096) as f:
            # One output line per sample, without a trailing newline.
            f.write('\n'.join(texts))

        print(f'-- {ext} set')
        if not texts:
            # max() over an empty split is undefined; report it and move on.
            print('no samples', flush=True)
            print('')
            continue

        # Token counting is cheap string work, so it runs in-process rather
        # than paying the cost of shipping every line to a worker pool.
        stats = [num_tokens_check(text) for text in texts]
        n, ni, nt = map(max, zip(*stats))

        print(f'max #tokens = {n}', flush=True)
        print(f'max #input tokens = {ni}', flush=True)
        print(f'max #target tokens = {nt}', flush=True)
        print('')


def _postprocessing(intext, tartext, encoding='prefix'):
    P_text = [infix_to_prefix(p, ) for p in P]
    Q_text = [sympy_to_text(q, representation='prefix', integer_encoding='none') for q in Q]
    P_text = [sympy_to_text(p, representation='prefix', integer_encoding='none') for p in P]
    Q_text = [sympy_to_text(q, representation='prefix', integer_encoding='none') for q in Q]
    P_text = ' [SEP] '.join(P_text)
    Q_text = ' [SEP] '.join(Q_text)

    return ' : '.join([P_text, Q_text])


if __name__ == '__main__':
    import sys

    # Usage: python <this script> DATA_PATH
    # Example DATA_PATH: 'data/gb_n=3_r=5_d=3_m=2_Gd=5_Gm=None_F=5'
    # An explicit check (not `assert`) so validation survives `python -O`.
    if len(sys.argv) != 2:
        sys.exit(f'usage: {sys.argv[0]} DATA_PATH')
    data_path = sys.argv[1]
    extension = 'raw'
    encoding = 'prefix'
    n_jobs = -1
    postprocessing(data_path, _postprocessing, extention=extension, encoding=encoding, n_jobs=n_jobs)


