import os
from argparse import ArgumentParser


def get_args():
    parser = ArgumentParser()
    
    parser.add_argument("--data-mode", type=str, default="clean", choices=("clean", "unlearn"))
    parser.add_argument("--file-path", type=str, required=True)
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--max-new-tokens", type=int, default=10)
    parser.add_argument("--ckpt_dir", type=str, default="")
    parser.add_argument(
        "--model-name", 
        type=str, required=True, 
        choices=(
            "mistral-7b-v0.3-bnb-4bit",
            "mistral-7b-instruct-v0.3-bnb-4bit",
            "Meta-Llama-3.1-8B-bnb-4bit",
            "Meta-Llama-3.1-8B-Instruct-bnb-4bit",
            "Meta-Llama-3.1-70B-bnb-4bit",
            "Meta-Llama-3.1-70B-Instruct-bnb-4bit",
            "Phi-3-mini-4k-instruct",
            "Phi-3-medium-4k-instruct",
            "gemma-7b-bnb-4bit",
            "gemma-7b-it-bnb-4bit",
        )
    )

    return parser.parse_args()