import argparse
import torch

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pde', type=str, default='KdV')

    parser.add_argument('--train_samples', type=int, default=4)
    parser.add_argument('--sindy_optimizer', type=str, default='lbfgs')
    parser.add_argument('--lbfgs_subsample', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=8192)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--log_interval', type=int, default=1)
    parser.add_argument('--save_interval', type=int, default=10)
    parser.add_argument('--save_dir', type=str, default='KdV')
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--w_sindy', type=float, default=1.0)
    parser.add_argument('--sindy_reg_type', type=str, default='l1')
    parser.add_argument('--w_sindy_reg', type=float, default=0.0)
    parser.add_argument('--st_freq', type=int, default=100)
    parser.add_argument('--threshold', type=float, default=0.1)
    parser.add_argument('--print_eq', action='store_true')

    parser.add_argument('--model', type=str, default='di-sindy')
    parser.add_argument('--w_sym_reg', type=float, default=0.0)
    parser.add_argument('--noise', type=float, default=0.0)

    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)

    args = parser.parse_args()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    return args