import argparse
import json
import numpy as np


def preprocess_input(ss, vocab, max_len):
    """Encode a sequence into a fixed-length list of vocabulary ids.

    The encoded sequence is ``<START>``, one id per element of ``ss``
    (elements missing from ``vocab`` map to ``<UNK>``), then ``<END>``,
    right-padded with ``<PAD>`` up to ``max_len`` entries.

    Args:
        ss: Iterable of symbols (e.g. a string) to encode.
        vocab: Mapping from symbol to id. ``<UNK>`` and ``<PAD>`` are only
            looked up when an unknown symbol or padding is actually needed.
        max_len: Exact length of the returned id list.

    Returns:
        A ``(ids, length)`` tuple where ``ids`` has ``max_len`` entries and
        ``length`` is the unpadded length, counting ``<START>`` and ``<END>``.

    Raises:
        Exception: If the encoded sequence (with markers) exceeds ``max_len``.
    """
    token_ids = [vocab["<START>"]]
    for symbol in ss:
        if symbol in vocab:
            token_ids.append(vocab[symbol])
        else:
            token_ids.append(vocab["<UNK>"])
    token_ids.append(vocab["<END>"])

    true_len = len(token_ids)
    if true_len > max_len:
        raise Exception("sequence of len {} detected, max len is {}".format(true_len, max_len))

    missing = max_len - true_len
    if missing > 0:
        token_ids.extend([vocab["<PAD>"]] * missing)

    # Invariant: the sequence is now exactly max_len long.
    assert len(token_ids) == max_len
    return token_ids, true_len

def preprocess_stack(ss, max_len):
    """Convert a string of digit characters into a fixed-length list of ints.

    The string is wrapped in a leading and a trailing ``"0"`` and then
    right-padded with ``"0"`` up to ``max_len`` characters; every character
    is converted with ``int``.

    Args:
        ss: String whose characters are each convertible by ``int``.
        max_len: Exact length of the returned list.

    Returns:
        A list of ``max_len`` ints.

    Raises:
        ValueError: If the wrapped string is longer than ``max_len``, or if a
            character cannot be converted to ``int``.
    """
    framed = "0" + ss + "0"
    # Validate explicitly: an assert would be stripped under ``python -O`` and
    # an over-long row would be returned silently.
    if len(framed) > max_len:
        raise ValueError(
            "sequence of len {} detected, max len is {}".format(len(framed), max_len)
        )
    padded = framed.ljust(max_len, "0")
    return [int(digit) for digit in padded]



def get_params():
    """Parse the command-line options of the script.

    Returns:
        An ``argparse.Namespace`` with ``max_len`` (int, default 52),
        ``inp_fpath`` (str, default "../data/dyck.json"), ``out_fpath``
        (str, default "../data/dyck.npz") and ``vocab_fpath`` (str,
        default None).
    """
    # (flag, type, default) for every supported option.
    option_table = (
        ("--max_len", int, 52),
        ("--inp_fpath", str, "../data/dyck.json"),
        ("--out_fpath", str, "../data/dyck.npz"),
        ("--vocab_fpath", str, None),
    )

    cli = argparse.ArgumentParser(description="dyck_generator")
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: encode every example of the input JSON with the
    # vocabulary and store the resulting arrays in a single .npz file.
    args = get_params()

    # --vocab_fpath has no usable default; fail with a clear message instead
    # of the TypeError that open(None) would raise.
    if args.vocab_fpath is None:
        raise SystemExit("--vocab_fpath is required (path to the vocabulary JSON file)")

    # Both files are only read, so open them read-only ("r+" would needlessly
    # require write permission).
    with open(args.inp_fpath, "r") as ff:
        data = json.load(ff)

    with open(args.vocab_fpath, "r") as ff:
        vocab = json.load(ff)

    # Each entry is (padded id list, unpadded length).
    encoded = [preprocess_input(d["inp"], vocab, args.max_len) for d in data]
    if not encoded:
        # zip(*[]) yields nothing, so the unpacking below would fail cryptically.
        raise SystemExit("no examples found in {}".format(args.inp_fpath))

    inp, ll = zip(*encoded)
    inp, ll = np.array(inp), np.array(ll)
    # The stack targets are not used at the moment, so a zero placeholder with
    # the same shape as `inp` is stored instead of
    # preprocess_stack(d["out"], args.max_len) for each example.
    stack = np.zeros([inp.shape[0], inp.shape[1]])
    print("inp : ", inp.shape, "stack  : ", stack.shape)
    np.savez(args.out_fpath, inp = inp, inp_len = ll, stack = stack)

    

    

    
