import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-b', default=100, type=int)
    parser.add_argument('--weight_decay', '-wd', default=1e-4, type=float)
    parser.add_argument('--epochs', '-e', default=200, type=int)
    parser.add_argument('--fine_tune_epochs', '-fe', default=50, type=int)
    parser.add_argument('--learning_rate', '-lr', default=0.1, type=float)
    parser.add_argument('--fine_tune_lr', '-flr', default=0.0025, type=float)
    parser.add_argument('--temp_init', '-t', default=5.0, type=float)
    parser.add_argument('--temp_inc', '-tinc', default=2.5, type=float)
    parser.add_argument('--init_epochs', '-ie', default=100, type=int)
    parser.add_argument('--resume', '-r', default=None, type=str)
    parser.add_argument('--fine_tune_from', '-ft', default=None, type=str)
    parser.add_argument('--gpus', '-g', default=None, type=int)
    parser.add_argument('--warmup_steps', '-w', default=5, type=int)
    parser.add_argument('--freeze_quant', '-fzq', default=False, action='store_true')
    parser.add_argument('--include_first_layer', '-if', default=False, action='store_true')
    parser.add_argument('--include_last_layer', '-il', default=False, action='store_true')
    parser.add_argument('--include_shortcut_layer', '-is', default=False, action='store_true')
    parser.add_argument('--decompose', default=None, type=int)
    parser.add_argument('--num_branches_first', '-nbf', default=2, type=int)
    parser.add_argument('--num_branches', '-nb', default=2, type=int)
    parser.add_argument('--num_branches_last', '-nbl', default=2, type=int)
    parser.add_argument('--dry_run', default=False, action='store_true')
    parser.add_argument('--gen_matrix_every_step', '-gm', default=False, action='store_true')
    parser.add_argument('--remove_portion', '-rp', default=0.9, type=float)
    parser.add_argument('--one_by_one', default=False, action='store_true')
    parser.add_argument('--dataset', '-d', required=True, choices=[
        'cifar10', 'cifar100', 'imagenet'
    ])
    parser.add_argument('--model', '-m', required=True, choices=[
        'resnet20', 'mobilenetv2'
    ])
    parser.add_argument('--print_horizontal_remove', default=False, action='store_true')
    args = parser.parse_args()

    args.exp_name = f'dbq_d{args.dataset}_m{args.model}_nb{args.num_branches}'
    if args.include_first_layer:
        args.exp_name += f'_nbf{args.num_branches_first}'
    if args.include_last_layer:
        args.exp_name += f'_nbl{args.num_branches_last}'

    return args
