from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import argparse


def get_args(description='PyTorch-based Subtask State Prediction', argv=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        description: Text shown at the top of ``--help`` output.
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list lets tests
            and wrapper scripts drive the parser programmatically.

    Returns:
        argparse.Namespace holding every option below. Note that the
        learning rate is stored under ``learning_rate_in_float`` (see the
        ``dest`` of ``--lr``).

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description=description)

    # paths
    parser.add_argument('--train_path', type=str,
                        default='',
                        help='training CLIP data path')
    parser.add_argument('--dataset_name', type=str,
                        default='ProceL', help='dataset type (ProceL, CrossTask)')
    parser.add_argument('--task_name', type=str,
                        default='assemble_clarinet',
                        help='task folder name')
    parser.add_argument('--seq_data_path', type=str,
                        default='',
                        help='sequence data path for postprocessing (optional)')
    parser.add_argument('--eval_path', type=str,
                        default='',
                        help='evaluation CLIP data path')
    parser.add_argument('--extract_ilp', dest='extract_ilp', action='store_true',
                        help='mark whether we want to save ILP outputs.')
    parser.add_argument('--next_step_pred', dest='next_step_pred', action='store_true',
                        help='mark whether we want to save it for the next step prediction. use train for graph, test for next step')

    parser.add_argument('--cp_root', type=str, default='checkpoint',
                        help='checkpoint dir root')
    parser.add_argument('--tb_root', type=str, default='tensorboard',
                        help='log dir root')
    parser.add_argument('--with_text', action='store_true',
                        help='whether to include the text representation or not')
    parser.add_argument('--infer_only', action='store_true', help='true if we want to run inference once only')

    # input processing
    parser.add_argument('--num_thread_reader', type=int, default=8,
                        help='')

    # training params
    parser.add_argument('--resume', type=str, default='',
                        help='resume training from last checkpoint of given name')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='opt algorithm')
    parser.add_argument('--epochs', default=600, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--train_batch_size', type=int, default=64,
                        help='batch size for training')
    parser.add_argument('--start_from', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    # The parsed value lands in args.learning_rate_in_float, not args.lr;
    # callers depend on this dest name.
    parser.add_argument('--lr', '--learning_rate', default=0.0003, type=float,
                        metavar='LR', help='initial learning rate', dest='learning_rate_in_float')
    parser.add_argument('--warmup_steps', type=int, default=100,
                        help='')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # Integer flag (0/1) rather than store_true, so it defaults to enabled.
    parser.add_argument('--cudnn_benchmark', type=int, default=1,
                        help='whether to enable cuDNN benchmark')
    parser.add_argument('--pin_memory', dest='pin_memory', action='store_true',
                        help='use pin_memory')
    parser.add_argument('--weight_init', type=str, default='uniform',
                        help='weights inits for our representation learning part')
    parser.add_argument('--input_dim', type=int, default=512,
                        help='input representation dim')
    parser.add_argument('--hidden_dim', type=int, default=512,
                        help='hidden dim of Bi-LSTM layer (each LSTM)')
    parser.add_argument('--num_layers', type=int, default=1,
                        help='number of Bi-LSTM layers')
    parser.add_argument('--text_att_type', type=int, default=4,
                        help='how do we process text representation')

    parser.add_argument('--resample_lowerbound', type=float, default=0.75,
                        help='resample few tasks only')

    parser.add_argument('--text_att_n_head', type=int, default=16,
                        help='number of multi_head attention')
    parser.add_argument('--text_feedforward_dim', type=int, default=512,
                        help='size of feedforward dim in multi_head attention')

    # eval param
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--eval_batch_size', type=int, default=64,
                        help='batch size evaluation')
    parser.add_argument('--n_eval', type=int, default=50,
                        help='Eval frequency')

    # logging
    parser.add_argument('--verbose', type=int, default=1, help='')
    parser.add_argument('--n_display', type=int, default=25,
                        help='Information display frequency')

    # distributed processing
    parser.add_argument('--world_size', default=-1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training')
    # Help text previously duplicated --dist_url's "url" wording.
    parser.add_argument('--dist_file', default='dist_file', type=str,
                        help='file used to set up distributed training')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=20220308, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing_distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')

    # argv=None makes argparse fall back to sys.argv[1:] (original behavior).
    args = parser.parse_args(argv)
    return args
