# coding=utf-8
import argparse

def update_args(args, ex_args):
    for k, v in ex_args.__dict__.items():
        setattr(args, k, v)

def _add_test_args(parser):
    parser.add_argument('--test_file', type=str, help='path to the test set file')
    parser.add_argument('--model_file', type=str, help='path to the model file')
    parser.add_argument('--beam_size', type=int, default=20, help='decoder beam size')
    parser.add_argument('--max_decode_step', type=int, help='maximum decode step')
    parser.add_argument('--search_order', type=str, default='dfs', help='the order of searching')

    parser.add_argument('--decode_file', type=str, default=None, help='path to save the k-best list')
    parser.add_argument('--report_file', type=str, default=None, help='path to save the report')
    parser.add_argument('--eval_file', type=str, default=None, help='path to save eval info')
    parser.add_argument('--cache_file', type=str, default=None, help='path to cache')
    parser.add_argument('--smooth', type=str, default='none', help='whether to do smoothing')
    parser.add_argument('--smooth_alpha', type=float, default=0., help='alpha in decision smoothing')

    parser.add_argument('--do_filter', default=False, action='store_true', help='no ex-guided')
    parser.add_argument('--do_naive', default=False, action='store_true', help='naive parse, no beam')
    parser.add_argument('--model_name', type=str, default='ASNParser', help='name of model')


def _add_synth_args(parser):
    _add_test_args(parser)
    parser.add_argument('--locator_file', type=str, default=None, help='error locator model')
    parser.add_argument('--do_noprune', default=False, action='store_true', help='no pruning')

def _add_model_args(parser):
    pass

def _add_train_args(parser):
    # arg_parser.add_argument('--cuda', action='store_true', default=False, help='Use gpu')
    parser.add_argument('--asdl_file', type=str, help='Path to ASDL grammar specification')
    parser.add_argument('--vocab', type=str, help='Path of the serialized vocabulary')
    parser.add_argument('--save_to', type=str, help='save the model to')
    parser.add_argument('--train_file', type=str, help='path to the training target file')
    parser.add_argument('--dev_file', type=str, help='path to the dev source file')

    parser.add_argument('--enc_hid_size', type=int, help='encoder hidden size')
    parser.add_argument('--src_emb_size', type=int, help='sentence embedding size')
    parser.add_argument('--field_emb_size', type=int, help='field embedding size')
    parser.add_argument('--dropout', type=float, default=0.3, help='dropout rate')

    parser.add_argument('--batch_size', type=int, help='batch size')
    parser.add_argument('--max_epoch', type=int, help='max epoch')

    parser.add_argument('--clip_grad', type=float, default=5.0, help='clip grad to')
    parser.add_argument('--lr', type=float, default=.001, help='learning rate')

    parser.add_argument('--log_every', type=int, help='log every iter')
    parser.add_argument('--run_val_after', type=int, help='run validation after')
    parser.add_argument('--max_decode_step', type=int, help='maximum decode step')
    parser.add_argument('--save_all', default=False, action='store_true', help='Save all intermediate checkpoints')


def _add_train_mul_args(parser):
    _add_train_args(parser)
    parser.add_argument('--io_vocab', type=str, help='Path of the serialized vocabulary')
    parser.add_argument('--io_hid_size', type=int, help='encoder hidden size')
    parser.add_argument('--io_emb_size', type=int, help='sentence embedding size')
    parser.add_argument('--io_pooling', type=str, default='max', help='pooling over io')

def _add_train_joint_args(parser):
    _add_train_mul_args(parser)
    parser.add_argument('--model_name', type=str, help='name of model')

def _add_train_locator_args(parser):
    parser.add_argument('--train_file', type=str, help='path to the training target file')
    parser.add_argument('--train_path_file', type=str, help='path to the training target file')
    parser.add_argument('--dev_file', type=str, help='path to the dev source file')
    parser.add_argument('--dev_path_file', type=str, help='path to the dev source file')
    parser.add_argument('--model_file', type=str, help='path to the model file')
    parser.add_argument('--search_order', type=str, default='dfs', help='the order of searching')


def parse_args(mode):
    parser = argparse.ArgumentParser()
    # parser.add_argument('--asdl_file', type=str, help='Path to ASDL grammar specification')

    # parser.add_argument('--vocab', type=str, help='Path of the serialized vocabulary')

    # parser.add_argument('--train_file', type=str, help='path to the training target file')
    # parser.add_argument('--dev_file', type=str, help='path to the dev source file')

    if mode == 'train':
        _add_train_args(parser)
    elif mode == 'train_mul':
        _add_train_mul_args(parser)
    elif mode == 'train_joint':
        _add_train_joint_args(parser)
    elif mode == 'test':
        _add_test_args(parser)
    elif mode == 'synthesize':
        _add_synth_args(parser)
    elif mode == 'train_searcher':
        _add_train_locator_args(parser)
    else:
        raise RuntimeError('unknown mode')
    
    args = parser.parse_args()
    print(args)
    return args
