import argparse

def config():
    """Build the command-line parser for the training script.

    Registers the training, data-set, network and logging options, then
    parses the current process arguments once in order to print the
    effective settings.

    Side effects:
        Calls ``parser.parse_args()`` (which reads ``sys.argv``) and prints
        the resulting namespace followed by a newline. As with any argparse
        parse, invalid arguments or ``-h/--help`` terminate the process via
        ``SystemExit`` at this point.

    Returns:
        argparse.ArgumentParser: the configured parser itself (not the parsed
        namespace); callers are expected to call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser()

    # NOTE: ``choices`` are given as ordered tuples rather than sets so that
    # the ``{a,b,c}`` listing shown in --help / usage / error messages is
    # stable from run to run (string set order varies with hash randomization).

    # training
    parser.add_argument("--loss", type=str, default='mse',
                        choices=('mse', 'exponential', 'logistic'),
                        help='loss function')
    parser.add_argument("--reduction", type=str, default='mean',
                        choices=('mean', 'sum'),
                        help='sum or average loss')
    parser.add_argument("--epoch", type=int, default=1000, help='number of epochs')
    parser.add_argument("--lr", type=float, default=0.1, help='learning rate')
    parser.add_argument("--reg", type=float, default=0, help='regularization, i.e. weight_decay')
    parser.add_argument("--init", type=float, default=0.01,
                        help='weight initialization, 0 means using pytorch default initialization')
    parser.add_argument("--hid_width", type=int, default=100, help='number of hidden neurons')
    parser.add_argument("--relu", type=float, default=0, help='negative slope, 1 for linear, 0 for relu')
    parser.add_argument("--bias", action="store_true", help='bias or unbiased linear layer')

    # param for data set
    parser.add_argument("--data", type=str, default='toy', help='data type')
    parser.add_argument("--dataset_size", type=int, default=1000, help='number of training samples')
    parser.add_argument("--noise", type=float, default=0, help='std of noise in the output; 0 for noiseless')

    # param for network
    parser.add_argument("--depth", type=int, default=2, help='number of layers ')
    parser.add_argument("--sweep", type=str, default='single',
                        choices=('single', 'compare', 'leaky_sweep'),
                        help='sweep option')

    # param for logging settings
    parser.add_argument("--track_weight", action="store_true", help="enable weights plot")

    # Echo the effective configuration; the parser (not the namespace) is
    # returned, so the caller parses again on its side.
    print(parser.parse_args(), '\n')
    return parser