import argparse

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', action = 'store', type = str, required = True, dest = 'exp_name')
    parser.add_argument('--indices', action = 'store', nargs = '*', type = int, required = True, dest = 'indices')
    parser.add_argument('--setting', action = 'store', type = str, dest = 'setting', default = 'imagenet', choices = ['imagenet', 'cifar', 'text'])
    parser.add_argument('--model', action = 'store', type = str, dest = 'model', default = None)
    parser.add_argument('--target_type', action = 'store', type = str, dest = 'target_type', default = 'conv')
    parser.add_argument('--replace', action = 'store', type = str, dest = 'replacement_type', default = 'independent', choices = ['independent', 'sequential',
                                                                                                                                  'joint', 'progressive',
                                                                                                                                  'progressive_rn50_to_rn18',
                                                                                                                                  'progressive_align_rn18_to_rn50'])

    parser.add_argument('--seq_len', action = 'store', type = int, default = 1024, dest = 'seq_len')
    parser.add_argument('--num_workers', action = 'store', type = int, default = 32, dest = 'num_workers')
    parser.add_argument('--full_rank', action = 'store_false', dest = 'low_rank')
    parser.add_argument('--rank', action = 'store', type = int, dest = 'rank', default = 1024)
    parser.add_argument('--svd_init', action = 'store_true', dest = 'svd_init')

    parser.add_argument('--repdist', action = 'store', type = str, dest = 'repdist', default = 'CKA', choices = ['CKA', 'Procrustes', 'MSE'])
    parser.add_argument('--cka_lr', action = 'store', type = float, dest = 'cka_lr', default = 1e-2)
    parser.add_argument('--cka_epochs', action = 'store', type = int, dest = 'cka_epochs', default = 10)
    parser.add_argument('--cka_batch', action = 'store', type = int, dest = 'cka_batch', default = 256)
    parser.add_argument('--distillation', action = 'store_true', dest = 'distillation')
    parser.add_argument('--distil_temp', action = 'store', type = int, dest = 'distil_temp', default = 4)
    parser.add_argument('--distil_alpha', action = 'store', type = float, dest = 'distil_alpha', default = 0.5)
    parser.add_argument('--progressive_lrs', action = 'store', nargs = '*', default = [1e-3], type = float)
    parser.add_argument('--progressive_epochs', action = 'store', nargs = '*', default = [10], type = int)
    parser.add_argument('--reload_progressive', action = 'store_true', dest = 'reload_progressive')
    parser.add_argument('--prog_ckpt', action = 'store', type = str, dest = 'prog_ckpt', default = 'replace_progressive_3_0')
    
    parser.add_argument('--ft_batch', action = 'store', type = int, dest = 'ft_batch', default = 256)
    parser.add_argument('--ft_lr', action = 'store', type = float, dest = 'ft_lr', default = 1e-4)
    parser.add_argument('--ft_wd', action = 'store', type = float, dest = 'ft_wd', default = 1e-2)
    parser.add_argument('--ft_epochs', action = 'store', type = int, dest = 'ft_epochs', default = 50)
    parser.add_argument('--ft_eval', action = 'store', type = str, dest = 'evaluate', choices = ['steps', 'epoch'], default = 'epoch')
    parser.add_argument('--ft_eval_steps', action = 'store', type = int, dest = 'eval_steps', default = 500)
    parser.add_argument('--warmup_epochs', action = 'store', type = int, dest = 'warmup_epochs', default = 5)
    parser.add_argument('--scheduler_type', action = 'store', type = str, dest = 'scheduler_type', choices = ['cosine', 'plateau'], default = 'cosine')

    parser.add_argument('--layer_weight', action = 'store_true', dest = 'layer_weight')
    parser.add_argument('--tune_linear', action = 'store_true', dest = 'tune_linear')
    parser.add_argument('--tune_rnn', action = 'store_true', dest = 'tune_rnn')
    parser.add_argument('--use_amp', action = 'store_true', dest = 'use_amp')
    parser.add_argument('--reload_cka', action = 'store_true', dest = 'reload_cka')
    parser.add_argument('--reload_cka_model', action = 'store', type = str, dest = 'reload_cka_model', default = 'replace_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20')
    parser.add_argument('--post_activation', action = 'store_true', dest = 'post_activation')
    parser.add_argument('--reload_ft', action = 'store_true', dest = 'reload_ft')
    parser.add_argument('--cka_scheduler', action = 'store_true', dest = 'cka_scheduler')
    parser.add_argument('--scheduler', action = 'store_true', dest = 'use_scheduler')
    parser.add_argument('--untrained', action = 'store_false', dest = 'pretrained')
    parser.add_argument('--task_loss', action = 'store_true', dest = 'task_loss')
    parser.add_argument('--local_loss', action = 'store_true', dest = 'local_loss')
    parser.add_argument('--replace_norm', action = 'store_true', dest = 'replace_norm')
    parser.add_argument('--baseline', action = 'store_true', dest = 'baseline')
    parser.add_argument('--weighting_strategy', action = 'store', type = str, dest = 'weighting_strategy', default = 'complicated', choices = ['simple', 'complicated'])
    parser.set_defaults(low_rank = True, tune_linear = False, reload_cka = False, reload_ft = False, scheduler = False, pretrained = True, 
                        task_loss = False, cka_scheduler = False, reload_progressive = False, svd_init = False, post_activation = False,
                        layer_weight = False, replace_norm = False, baseline = False, tune_rnn = False, use_amp = False)
    args = parser.parse_args()
    return args