from dyck_generator import DyckLanguage, all_pairs, all_letters
import argparse
import json
import random


def get_params(argv=None):
    """Parse the command-line options of the Dyck dataset generator.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, so existing callers that pass
            nothing behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``num_par``, ``min_size``,
        ``max_size``, ``generate_negative``, ``total_samples``,
        ``file_path`` and ``vocab_path``.
    """
    parser = argparse.ArgumentParser(description='dyck_generator')
    parser.add_argument('--num_par', type=int, default=2)
    parser.add_argument('--min_size', type=int, default=2)
    parser.add_argument('--max_size', type=int, default=50)
    parser.add_argument('--generate_negative', action="store_true")
    parser.add_argument('--total_samples', type=int, default=5000)
    parser.add_argument('--file_path', type=str, default="../data/dyck.json")
    # No vocabulary file is written unless a path is given explicitly.
    parser.add_argument('--vocab_path', type=str, default=None)
    return parser.parse_args(argv)


def main():
    """Generate a Dyck-language dataset and write it to disk as JSON.

    Reads the command-line options via ``get_params()``, builds a
    ``DyckLanguage`` generator, produces either the regular or the negative
    sample set, and dumps a list of ``{"inp", "out", "len"}`` records to
    ``args.file_path``. When ``--vocab_path`` is given, a symbol-to-index
    vocabulary is additionally written to that path.
    """
    args = get_params()

    ## Parameters of the Probabilistic Dyck Language
    # NOTE(review): these are passed positionally as the 2nd and 3rd
    # DyckLanguage constructor arguments; presumably the grammar's
    # production probabilities -- confirm against dyck_generator.
    p_val = 0.5
    q_val = 0.25

    # Create a Dyck language generator
    dyck = DyckLanguage(args.num_par, p_val, q_val)
    # The returned vocabulary is not used below; the call is kept only in
    # case return_vocab() has side effects -- TODO confirm and drop.
    dyck.return_vocab()

    print('Loading data...')

    # Both generators are called with the same arguments; only the third
    # element of their return value is unused here.
    if args.generate_negative:
        generate = dyck.negative_set_generator
    else:
        generate = dyck.training_set_generator
    inputs, outputs, _ = generate(args.total_samples, args.min_size, args.max_size)

    # One record per sample; assumes the generator returns inputs and
    # outputs of equal length so they can be paired element-wise.
    data = [
        {
            "inp": inp,
            "out": out,
            "len": len(inp),
        }
        for inp, out in zip(inputs, outputs)
    ]

    # Plain write mode is enough: the file is never read back here.
    with open(args.file_path, "w", encoding="utf-8") as ff:
        json.dump(data, ff, indent=4)

    if args.vocab_path is not None:
        # Indices 0-3 are reserved for the special tokens; the language's
        # own symbols follow from index 4 in Dyck.vocabulary order.
        vocab = {
            "<PAD>": 0,
            "<START>": 1,
            "<END>": 2,
            "<UNK>": 3,
        }
        vocab.update({sym: i + 4 for i, sym in enumerate(dyck.vocabulary)})

        with open(args.vocab_path, "w", encoding="utf-8") as ff:
            json.dump(vocab, ff, indent=4)


if __name__ == "__main__":
    main()

        