import argparse

model_names = {'base': 'bert-base-uncased',
        'twitter': 'cardiffnlp/twitter-roberta-base',
        'longformer': 'allenai/longformer-base-4096',
        }

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='data')
    parser.add_argument('--data', default='snli_balanced')
    parser.add_argument('--aim_exp', default='entropy-curr')
    parser.add_argument('--ckpt')
    parser.add_argument('--ckpt_dir')
    parser.add_argument('--model_name', default='base')
    parser.add_argument('--num_labels', type=int, default=3)
    parser.add_argument('--num_ent_labels', type=int, default=3)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--curr', default='sl')
    parser.add_argument('--sl_lam', type=float, default=1)
    parser.add_argument('--dp_alpha', type=float, default=0.9)
    parser.add_argument('--aux_ent', action='store_true')
    parser.add_argument('--aux_ent2', action='store_true')
    parser.add_argument('--feed_ent', action='store_true')
    parser.add_argument('--soft_ent', action='store_true')
    parser.add_argument('--balance_logits', action='store_true')
    parser.add_argument('--balance_aux_ent', action='store_true')
    parser.add_argument('--burn_in', type=float, default=0)
    parser.add_argument('--detach_loss', action='store_true')
    parser.add_argument('--ent_alpha', type=float, default=1)
    parser.add_argument('--sl_mode', default='avg')
    parser.add_argument('--spl_mode', default='easy')
    parser.add_argument('--ent_cfg', default='6')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument('--save_losses', action='store_true')
    parser.add_argument('--seed', default = '0')
    parser.add_argument('--study_name', default='test_study.pkl')
    parser.add_argument('--study_dir', default='studies')
    parser.add_argument('--noise', type=float, default=0.0)
    parser.add_argument('--data_fraction', type=float, default=1.0)
    parser.add_argument('--lng')
    args = parser.parse_args()
    args.seed = [int(x) for x in args.seed.split(',')]
    args.model_name = model_names.get(args.model_name, args.model_name)
    return args
