import argparse
import os

def setup_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="./dataset/")
    parser.add_argument("--pickle_folder", type=str, default="./dataset/")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--train_batch_size", type=int, default=12)
    parser.add_argument("--eval_batch_size", type=int, default=12)

    parser.add_argument("--doc_stride", type=int, default=128)
    parser.add_argument("--max_seq_length", type=int, default=384)
    parser.add_argument("--max_mention_length", type=int, default=30)
    parser.add_argument("--max_query_length", type=int, default=64)
    parser.add_argument("--num_train_epochs", type=int, default=2)
    parser.add_argument("--debug", action="store_true")

    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--warmup_steps", type=int, default=0) # Actually, 0.06*t_total
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--weight_decay", type=float, default=0.01) # Changed 20210913
    parser.add_argument("--bert_model", type=str, default="bert-base-uncased")

    # parser.add_argument("--model_dir", type=str, default="./save/no-name")
    parser.add_argument("--baseline", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--version_2_with_negative", action="store_true")
    parser.add_argument("--output_dir", type=str, default="./save/checkpoint")
    parser.add_argument("--n_best_size", type=int, default=20)
    parser.add_argument("--verbose_logging", action="store_true")
    parser.add_argument("--max_answer_length", type=int, default=30)
    parser.add_argument("--save_steps", type=int, default=0)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)

    parser.add_argument("--null_score_diff_threshold", type=float, default=0.0)
    parser.add_argument("--local_rank", type=int, default=-1)

    parser.add_argument("--read_data", action="store_true",
                        help="read data from json file")

    parser.add_argument("--do_eval", action="store_true")
    parser.add_argument("--no_report", action="store_true")
    parser.add_argument("--checkpoint", default="./save/checkpoint", type=str)
    parser.add_argument("--prediction_file", default=None, type=str)
    parser.add_argument("--lm_type", type=str, default="bert", choices=["bert", "roberta"])
    parser.add_argument("--evaluate_on_test", action="store_true")

    parser.add_argument("--eval_target", type=str, default="all", choices=["all", "with_knowledge", "without_knowledge"])
    
    args = parser.parse_args()

    args.do_lower_case = True
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)
    print(f"SEED: {args.seed}")
    return args 