import argparse
import os


root_path = "/root/workspace/self-improvement"


def parse_args():
    parser = argparse.ArgumentParser(description='Args of vllm inference')

    # Model Args
    parser.add_argument('--model_path',
                        # default=os.path.join(root_path, "ckpts/Llama-3.1-8B-Instruct/MedQA_en_dqa_qa_lr_2e-05_init_train_True"),
                        # default=os.path.join(root_path, "ckpts/Llama-3.2-3B-Instruct/MedQA_en_dqa_qa_lr_2e-05_init_train_True"),
                        # default=os.path.join(root_path, ""),
                        type=str)
    parser.add_argument('--chat_template_name',
                        default="llama-3.1-chat",
                        # default="llama-3.2-chat",
                        type=str)
    parser.add_argument('--tgt_dir', default="vllm_outputs", type=str)
    parser.add_argument('--bf16', default=True, type=bool)
    parser.add_argument('--fp16', default=False, type=bool)
    parser.add_argument('--language', default='en', type=str)
    parser.add_argument('--dataset', default='MedQA_en', type=str)
    parser.add_argument('--dataset_type', default='dqa_qa', type=str)
    parser.add_argument('--prompt_template_dir', default='prompts', type=str)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--do_eval', default=False, type=bool)


    args = parser.parse_args()
    return args