import argparse
import tomita_generator as tomita
from tqdm import tqdm

def get_params():
    """Parse this script's command-line options.

    Three options are registered, all of them required:
    ``--grammar`` (parsed as ``int``), ``--src_file`` (``str``) and
    ``--target_file`` (``str``).

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``, exposing
        the attributes ``grammar``, ``src_file`` and ``target_file``.
    """
    cli = argparse.ArgumentParser(description='tomita_lstm')

    # (flag, converter) pairs; every option is mandatory.
    required_options = (
        ('--grammar', int),
        ('--src_file', str),
        ('--target_file', str),
    )
    for flag, converter in required_options:
        cli.add_argument(flag, type=converter, required=True)

    return cli.parse_args()


def _prefix_labels(grammar, string):
    """Return one membership label per non-empty prefix of ``string``.

    Character ``i`` of the result is "1" when ``string[:i + 1]`` is accepted
    by ``grammar.belongs_to_lang`` (truthy result) and "0" otherwise, so the
    result always has the same length as ``string``; an empty string yields
    an empty label string.
    """
    return "".join(
        "1" if grammar.belongs_to_lang(string[:end]) else "0"
        for end in range(1, len(string) + 1)
    )


if __name__ == "__main__":
    args = get_params()

    # The source file is only read, so open it read-only; the previous "r+"
    # mode additionally required write permission on the input.
    with open(args.src_file, "r") as src:
        data = [line.replace("\n", "") for line in src]

    # Resolve e.g. ``--grammar 3`` to ``tomita.Tomita3Language``.
    klass_name = "Tomita{}Language".format(args.grammar)
    klass = getattr(tomita, klass_name)
    # NOTE(review): the meaning of the two 0.4 constructor arguments is
    # defined in tomita_generator and is not visible here; kept unchanged.
    grammar = klass(0.4, 0.4)

    labels_data = [_prefix_labels(grammar, string) for string in tqdm(data)]

    # One label line per input line, joined without a trailing newline
    # (same output format as before). Plain "w" suffices: nothing is read back.
    with open(args.target_file, "w") as dst:
        dst.write("\n".join(labels_data))


