import argparse
import os

def get_args_parser(args=None) -> argparse.Namespace:
    """Build the 'Protein Captioning' CLI and return the parsed options.

    Despite its name, this function returns the parsed ``Namespace`` rather
    than the ``ArgumentParser`` itself.

    Args:
        args: Optional sequence of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is parsed, which is what argparse
            does for a plain ``parse_args()`` call.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, on ``-h``, or
            when the local rank is given neither on the command line nor
            through the ``LOCAL_RANK`` environment variable.
    """
    parser = argparse.ArgumentParser(description='Protein Captioning',
                                     add_help=True,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    ## dataloader
    parser.add_argument('--batch-size', default=8, type=int, help='batch size')
    parser.add_argument('--num-workers', type=int, default=8, help='nb of workers')
    parser.add_argument('--seq-len-max', type=int, default=100, help='training protein length max')
    parser.add_argument('--seq-len-min', type=int, default=20, help='training protein length min')
    parser.add_argument('--func-len-max', type=int, default=50, help='training text length')
    parser.add_argument('--text-vocab-size', type=int, default=30522, choices=[19320, 10022, 28860, 30522], help='text vocab size')
    parser.add_argument('--protein-vocab-size', type=int, default=30, help='protein vocab size')
    parser.add_argument("--task-name", type=str, default="seq100_fun50", choices=['seq100_fun50', 'seq200_fun50', 'seq200_fun100'], help='task name')
    parser.add_argument("--seq-vocab-path", type=str, default="dataset/data/protein_vocab.txt", help='seq vocab path')
    parser.add_argument("--func-vocab-path", type=str, default="dataset/data/vocab.txt", help='func vocab path')
    parser.add_argument("--dataset-path", type=str, default="dataset/data/", help='dataset path')
    parser.add_argument("--dataset-name", type=str, default="filtered_dataset_seq100_fun50.ndjson", help='dataset name')

    ## optimization
    parser.add_argument('--total-epoch', default=100, type=int, help='number of total epochs to run')
    parser.add_argument('--total-iter', default=1000, type=int, help='number of total iterations to run')
    parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup')
    parser.add_argument('--lr', default=1e-4, type=float, help='max learning rate')
    parser.add_argument('--lr-scheduler', default=[600], nargs="+", type=int, help="learning rate schedule (iterations)")
    parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")
    parser.add_argument('--weight-decay', default=1e-6, type=float, help='weight decay')
    # The previous help text here was copied from an unrelated option.
    parser.add_argument('--optimizer', default='adamw', type=str, choices=['adam', 'adamw'], help='optimizer type')

    ## protein encoder arch
    parser.add_argument("--latent-conv-dim", type=int, default=256, help="embedding dimension")
    parser.add_argument("--latent-trans-dim", type=int, default=512, help="embedding dimension")
    parser.add_argument("--down-t", type=int, default=2, help="downsampling rate")
    parser.add_argument("--stride-t", type=int, default=2, help="stride size")
    parser.add_argument("--depth", type=int, default=3, help="depth of the network")
    parser.add_argument("--nb-layer-pe", type=int, default=4, help="nb of transformer layers")
    parser.add_argument('--nb-head-pe', type=int, default=8, help='nb of heads')
    parser.add_argument('--kernel-size', type=int, default=15, help='res conv kernel size')
    parser.add_argument("--pos-grad-pe", action='store_true', help="protein encoder position embedding grad")

    ## protein translator arch
    parser.add_argument("--block-size", type=int, default=51, help="seq len")
    parser.add_argument("--embed-dim-pt", type=int, default=512, help="embedding dimension")
    parser.add_argument("--protein_dim", type=int, default=512, help="latent dimension in the protein feature")
    parser.add_argument("--num-layers", type=int, default=8, help="nb of transformer layers")
    parser.add_argument("--n-head-pt", type=int, default=8, help="nb of heads")
    parser.add_argument("--ff-rate", type=int, default=4, help="feedforward size")
    parser.add_argument("--drop-out-rate", type=float, default=0.1, help="dropout ratio in the pos encoding")

    ## resume
    parser.add_argument("--resume-pth", type=str, default=None, help='resume vq pth')

    ## output directory
    parser.add_argument('--out-dir', type=str, default='output/', help='output directory')
    parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir')

    ## other
    parser.add_argument('--print-iter', default=200, type=int, help='print frequency')
    parser.add_argument('--eval-iter', default=5000, type=int, help='evaluation frequency')
    parser.add_argument('--seed', default=1234, type=int, help='seed for initializing training. ')

    # distributed training
    # Launchers differ in how they hand over the local rank: some pass
    # ``--local_rank``, newer ones pass ``--local-rank``, and torchrun exports
    # the LOCAL_RANK environment variable instead. Accept all three. The flag
    # stays mandatory only when the environment variable is absent, so a
    # launch that provides no rank at all still fails loudly as before.
    env_local_rank = os.environ.get("LOCAL_RANK")
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        # A string default is run through ``type`` by argparse,
                        # so a malformed LOCAL_RANK yields a normal usage error.
                        default=env_local_rank,
                        required=env_local_rank is None,
                        help='local rank for DistributedDataParallel')

    return parser.parse_args(args)