import torch
import os
import argparse
from tqdm import tqdm

def get_tomita_vocab():
    """Return a fresh token-to-index mapping for encoding Tomita strings.

    Index 0 is the padding token, the binary alphabet symbols "0" and "1"
    come next, and the <START>/<END> delimiters take the last two indices.
    """
    ordered_tokens = ("<PAD>", "0", "1", "<START>", "<END>")
    return {token: idx for idx, token in enumerate(ordered_tokens)}

def format_partial_targets(strings, max_len):
    """Encode per-prefix label strings as a padded target tensor.

    Each character of a string is converted with ``int``. The encoded row
    gets one extra label in front, for the <START> input position, and one
    at the back, for the <END> position. It is then 0-padded up to
    ``max_len``.

    Args:
        strings: iterable of label strings (e.g. "0110").
        max_len: width of every output row.

    Returns:
        (targets, lengths): ``targets`` is a long tensor of shape
        (len(strings), max_len). ``lengths`` is a long tensor holding each
        row's unpadded length (number of labels + 2).

    Raises:
        ValueError: if a character cannot be parsed by ``int``, or if an
            encoded row is longer than ``max_len``.
    """
    rows = []
    lengths = []
    for string in strings:
        # <START> in the input corresponds to the empty string, which is
        # always accepted.
        row = [1] + [int(ch) for ch in string]
        # For the <END> input the last prediction is copied. For an empty
        # label string that is the <START> prediction itself, so empty
        # lines need no special case.
        row.append(row[-1])
        row_len = len(row)
        if row_len > max_len:
            raise ValueError(
                f"encoded target of length {row_len} exceeds "
                f"max_len={max_len} for {string!r}"
            )
        lengths.append(row_len)
        # <PAD> predictions are all 0.
        row.extend([0] * (max_len - row_len))
        rows.append(row)

    # view() keeps a 2-D (0, max_len) shape even when `strings` is empty.
    targets = torch.tensor(rows, dtype=torch.long).view(len(rows), max_len)
    lengths = torch.tensor(lengths, dtype=torch.long)
    return targets, lengths

def format_strings(strings, max_len, vocab):
    """Encode strings as padded token-index rows wrapped in <START>/<END>.

    Args:
        strings: iterable of strings whose characters are keys of ``vocab``.
        max_len: width of every output row.
        vocab: mapping from characters and the "<START>", "<END>" and
            "<PAD>" tokens to integer indices.

    Returns:
        (encoded, lengths): ``encoded`` is a long tensor of shape
        (len(strings), max_len). ``lengths`` is a long tensor holding each
        row's unpadded length (number of characters + 2).

    Raises:
        KeyError: if a character or special token is missing from ``vocab``.
        ValueError: if an encoded row is longer than ``max_len``.
    """
    start_id = vocab["<START>"]
    end_id = vocab["<END>"]
    pad_id = vocab["<PAD>"]

    rows = []
    lengths = []
    for string in strings:
        row = [start_id] + [vocab[ch] for ch in string] + [end_id]
        row_len = len(row)
        if row_len > max_len:
            raise ValueError(
                f"encoded string of length {row_len} exceeds "
                f"max_len={max_len} for {string!r}"
            )
        lengths.append(row_len)
        row.extend([pad_id] * (max_len - row_len))
        rows.append(row)

    # view() keeps a 2-D (0, max_len) shape even when `strings` is empty.
    encoded = torch.tensor(rows, dtype=torch.long).view(len(rows), max_len)
    lengths = torch.tensor(lengths, dtype=torch.long)
    return encoded, lengths
    

def get_params():
    """Parse the preprocessing script's command-line options.

    Usage:
        python preprocess.py --infile IN --outfile OUT --max_len N [--partial_target]

    Returns:
        argparse.Namespace with ``infile``, ``outfile`` (str), ``max_len``
        (int) and ``partial_target`` (bool, False unless the flag is given).
    """
    cli = argparse.ArgumentParser(description='tomita_lstm')
    for path_flag in ('--infile', '--outfile'):
        cli.add_argument(path_flag, type=str, required=True)
    cli.add_argument('--max_len', type=int, required=True)
    cli.add_argument('--partial_target', action="store_true")
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry: read one string per line from --infile, encode the lines,
    # and save the tensors to --outfile.
    args = get_params()
    # The input is only read. Mode "r+" (read/write) would fail with
    # PermissionError on read-only input files.
    with open(args.infile, "r", encoding="utf-8") as ff:
        strings = ff.read().splitlines()

    if args.partial_target:
        data, lengths = format_partial_targets(strings, args.max_len)
    else:
        data, lengths = format_strings(strings, args.max_len, get_tomita_vocab())
    torch.save({"strings": data, "lengths": lengths}, args.outfile)
