def set_hyperparams(args, t):
    """Override fine-tuning / pruning hyperparameters on ``args`` in place.

    Presets exist only for ``bert-base-uncased``; for any other
    ``args.model_name_or_path`` the function returns without touching
    ``args``.

    Args:
        args: Namespace-like object (e.g. ``argparse.Namespace``). Reads
            ``model_name_or_path`` and, depending on ``t``, ``prune_ratio``
            or ``task_name``. Writes ``learning_rate``, ``num_train_epochs``,
            ``max_length`` (plus the legacy alias ``max_lengh``) and
            ``per_device_train_batch_size``, and the method-specific
            attributes described under ``t``.
        t: Pruning-method identifier.
            ``'mag'``  -> only the base settings are applied.
            ``'imp'``  -> sets ``num_replay`` when ``args.prune_ratio`` is
                          (approximately) one of 0.2, 0.4, 0.6, 0.8.
            ``'fgmp'`` -> sets ``prune_end_it``, ``prune_freq_it`` and
                          ``lr_mask`` when ``args.task_name`` is a known
                          GLUE task key.
            Any other value, prune ratio or task name leaves the
            corresponding attributes unset.

    Returns:
        None; ``args`` is mutated.
    """
    import math
    import numbers

    if args.model_name_or_path != 'bert-base-uncased':
        return

    # (prune_ratio, num_replay) presets for the 'imp' method.
    imp_num_replay = (
        (0.2, 2),
        (0.4, 4),
        (0.6, 6),
        (0.8, 9),
    )
    # task_name -> (prune_end_it, prune_freq_it, lr_mask) for 'fgmp'.
    fgmp_by_task = {
        'cola': (800, 2, 5e-5),
        'sst2': (6300, 10, 1e-4),
        'mrpc': (325, 25, 5e-5),
        'stsb': (540, 2, 1e-4),
        'qqp': (34100, 25, 5e-5),
        'mnli': (36800, 50, 1e-4),
        'qnli': (9800, 25, 1e-4),
    }

    args.learning_rate = 2e-5
    args.num_train_epochs = 3
    args.max_length = 128
    # The attribute used to be written with this misspelling; keep it so
    # any code still reading `max_lengh` sees the same value.
    args.max_lengh = args.max_length
    args.per_device_train_batch_size = 32  # when using a single GPU

    if t == 'imp':
        ratio = args.prune_ratio
        # Tolerant comparison: a ratio produced by arithmetic (e.g. 1 - 0.8)
        # is not bit-identical to the literal 0.2. Non-numeric values
        # simply match nothing, as with the former `==` checks.
        if isinstance(ratio, numbers.Real):
            for preset_ratio, num_replay in imp_num_replay:
                if math.isclose(ratio, preset_ratio):
                    args.num_replay = num_replay
                    break
    elif t == 'fgmp':
        preset = fgmp_by_task.get(args.task_name)
        if preset is not None:
            args.prune_end_it, args.prune_freq_it, args.lr_mask = preset
    # t == 'mag' (or anything else): base settings only.
