import argparse


def parse(argv=None):
    """Build the command-line parser, parse the arguments and post-process them.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        A ``(args, state)`` tuple. ``args`` is the parsed
        ``argparse.Namespace``; ``state`` is an independent plain-dict copy
        of the same option-name -> value pairs (after post-processing).

    Post-processing applied after parsing:
        * ``train_rational``, ``loadopt`` and ``use_wt_in_test`` are declared
          as strings and converted to ``bool`` with ``to_bool``.
        * ``seed`` becomes ``None`` when left empty, otherwise ``int``
          (a non-integer value raises ``ValueError``).
    """
    parser = argparse.ArgumentParser(description='PyTorch')

    # Datasets
    parser.add_argument('--train_dir', default='', type=str, help='train set dir')
    parser.add_argument('--test_dir', default='', type=str, help='test set dir')
    parser.add_argument('--ckpt', default='test', type=str, help='checkpoint')

    # Optimization
    parser.add_argument('--optim', default='adam', type=str, help='optimizer')
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--epoch', default=1000, type=int, help='epoch')

    # Data segmentation / loading
    parser.add_argument('--train_nframe', default=10, type=int,
                        help='number of frames in a training segment')
    parser.add_argument('--test_nframe', default=-1, type=int,
                        help='number of frames in a test segment')
    parser.add_argument('--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')

    # Evaluation
    parser.add_argument('--test_freq', default=-1, type=int,
                        help='frequency of running test')
    parser.add_argument('--use_wt_in_test', default='False', type=str,
                        help='whether weight is used in test')

    # Resuming; '__NO_PARAM__' is the "no file given" sentinel
    # (presumably checked by the caller — confirm).
    parser.add_argument('--loadparam', default='__NO_PARAM__', type=str,
                        help='put parameter file if load parameter from file')
    parser.add_argument('--loadopt', default='False', type=str,
                        help='whether to load optimizer state')

    # Model
    parser.add_argument('--train_rational', default='True', type=str,
                        help='whether to train the rational activation')
    parser.add_argument('--beta', default=1, type=float, help='forgetting parameter')

    # Reproducibility: empty string means "do not set a seed".
    parser.add_argument('--seed', default='', type=str, help='whether to set seed')

    args = parser.parse_args(argv)

    # Boolean options arrive as strings; convert them after parsing.
    args.train_rational = to_bool(args.train_rational)
    args.loadopt = to_bool(args.loadopt)
    args.use_wt_in_test = to_bool(args.use_wt_in_test)
    args.seed = None if args.seed == '' else int(args.seed)

    # Public API instead of the private Namespace._get_kwargs(); dict() makes
    # an independent copy so mutating `state` does not alter `args`.
    state = dict(vars(args))

    return args, state


def to_bool(string):
    """Interpret a command-line flag value as a boolean.

    ``'True'`` and other common truthy spellings (``'true'``, ``'1'``,
    ``'yes'``, ``'y'``, ``'t'``, ``'on'``) map to ``True``. Matching is
    case-insensitive and ignores surrounding whitespace. An actual ``bool``
    is returned unchanged. Any other value maps to ``False``, so this stays
    as lenient as before.

    Args:
        string: The raw option value, normally a ``str``.

    Returns:
        ``True`` or ``False``.
    """
    if isinstance(string, bool):
        return string
    # Previously only the exact string 'True' counted; 'true' or '1' on the
    # command line silently became False.
    return str(string).strip().lower() in {'true', '1', 'yes', 'y', 't', 'on'}