import argparse
import json
import os
import random

def get_params():
    """Parse the command-line options for the dataset splitter.

    Options:
        --file_path  path of the JSON file to split (str).
        --shuffle    flag; when given the samples are shuffled before splitting.
        --test       number of samples placed in the test split (int).
        --val        number of samples placed in the validation split (int).

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    cli = argparse.ArgumentParser(description='dyck_generator')
    cli.add_argument('--file_path', default="../data/dyck.json", type=str)
    cli.add_argument('--shuffle', action="store_true")
    # --test and --val share the same shape: an int count defaulting to 1000.
    for size_flag in ('--test', '--val'):
        cli.add_argument(size_flag, default=1000, type=int)
    return cli.parse_args()

def split_dataset(data, test_size, val_size):
    """Partition *data* into consecutive test / validation / train slices.

    The first ``test_size`` items form the test split, the next ``val_size``
    items the validation split, and everything after that the train split.

    Args:
        data: a sequence of samples (e.g. the list loaded from JSON).
        test_size: number of samples for the test split.
        val_size: number of samples for the validation split.

    Returns:
        A dict with keys ``"test_data"``, ``"val_data"`` and ``"train_data"``.

    Raises:
        ValueError: if a size is negative, or if the two sizes together
            exceed ``len(data)``. Without this check the slices would
            silently overlap, be truncated, or leave the train split empty.
    """
    if test_size < 0 or val_size < 0:
        raise ValueError(
            f"split sizes must be non-negative (test={test_size}, val={val_size})"
        )
    train_idx = test_size + val_size
    if train_idx > len(data):
        raise ValueError(
            f"test ({test_size}) + val ({val_size}) exceeds the number of "
            f"samples ({len(data)})"
        )
    return {
        "test_data": data[:test_size],
        "val_data": data[test_size:train_idx],
        "train_data": data[train_idx:],
    }


def split_output_path(file_path, split_name):
    """Return the output path for one split: ``<stem>_<split_name>.json``.

    The original extension (whatever it is) is replaced, rather than blindly
    chopping the last five characters off the path.
    """
    root, _ext = os.path.splitext(file_path)
    return f"{root}_{split_name}.json"


if __name__ == "__main__":
    args = get_params()
    # The input is only read, so open it read-only (``r+`` would fail on
    # read-only files).
    with open(args.file_path, "r", encoding="utf-8") as ff:
        data = json.load(ff)

    if args.shuffle:
        random.shuffle(data)

    # argparse already converted --test / --val to int.
    dsplit = split_dataset(data, args.test, args.val)

    for ttype in ["train", "test", "val"]:
        fpath = split_output_path(args.file_path, ttype)
        with open(fpath, "w", encoding="utf-8") as ff:
            json.dump(dsplit[f"{ttype}_data"], ff)


