import argparse

def arg_parse(argv=None):
    """Build the demo's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, exactly as the
            original zero-argument call did. Passing an explicit list lets
            the function be used from tests or notebooks.

    Returns:
        argparse.Namespace with the attributes ``num_epochs``, ``model``,
        ``invlayer``, ``lr``, ``weight_decay``, ``batch_size``,
        ``latent_dim``, ``num_layers``, ``num_coupling_layers``, ``repeat``
        and ``scaler_flag``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description='Demo',
        # Appends "(default: ...)" to each option's help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--num_epochs', '-e', type=int, dest='num_epochs', default=200, help='NF Epochs')
    parser.add_argument('--model', '-m', type=str, default='nice', dest='model', help='Model')
    # Boolean switch: False unless the flag is given.
    parser.add_argument('--linearinvertible', '-inv', dest='invlayer', action="store_true", help='Linear Invertible Layer')
    parser.add_argument('--lr', '-lr', type=float, default=1e-3, dest='lr', help='NF Learning Rate')
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4, dest='weight_decay', help='weight decay for model')
    parser.add_argument('--batch_size', '-bs', type=int, default=100, dest='batch_size', help='batch size')
    parser.add_argument('--latent_dim', '-ld', type=int, default=128, dest='latent_dim', help='latent dimension')
    parser.add_argument('--num_layers', '-nl', type=int, default=2, dest='num_layers', help='# of Layers')
    # Help text previously duplicated --num_layers ('# of Layers'), making
    # --help ambiguous; it now names the coupling layers explicitly.
    parser.add_argument('--coupling_layers', '-co', type=int, default=10, dest='num_coupling_layers', help='# of Coupling Layers')
    parser.add_argument('--repeat_num', '-rp', type=int, default=10, dest='repeat', help='repeat number')
    # Boolean switch: False unless the flag is given.
    parser.add_argument('--scaler', '-s', dest='scaler_flag', action="store_true", help='scaler_flag')

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)