
import json
from argparse import ArgumentParser, Namespace
from typing import Optional, Sequence

from eval_utils import DATA_NAME_TO_MAX_NEW_TOKENS


def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
    """Build the evaluation command-line parser and parse arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which matches the previous zero-argument
            behaviour. Pass an explicit list to parse programmatically
            (e.g. from tests or notebooks).

    Returns:
        The parsed ``Namespace``. ``stop_idx`` is ``None`` when not given.
        ``hyper_param`` is the JSON-decoded value of ``--hyper_param``, or a
        fresh empty dict when the flag is omitted.

    Raises:
        SystemExit: Raised by argparse on invalid input, e.g. a missing
            ``--task``, a value outside ``choices``, or malformed JSON for
            ``--hyper_param``.
    """
    p = ArgumentParser()
    p.add_argument(
        "--task",
        type=str,
        # NOTE(review): choice validation is disabled, presumably so task
        # names outside DATA_NAME_TO_MAX_NEW_TOKENS can be passed -- confirm.
        # choices=list(DATA_NAME_TO_MAX_NEW_TOKENS.keys()) + ["all"],
        required=True,
        help=(
            'Which task to use. Note that "all" can only be used in '
            "`compute_scores.py`."
        ),
    )
    p.add_argument(
        "--data_dir", type=str, default="../data", help="The directory of data."
    )
    p.add_argument(
        "--output_dir",
        type=str,
        default="../results",
        help="Where to dump the prediction results.",
    )
    p.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-350m",
        help=(
            "For `compute_scores.py` only, specify which model you want to "
            "compute the score for."
        ),
    )
    p.add_argument(
        "--num_eval_examples",
        type=int,
        default=-1,
        help="The number of test examples to use, use all examples in default.",
    )
    p.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help=(
            "The index of the first example to infer on. This is used if you "
            "want to evaluate on a (contiguous) subset of the data."
        ),
    )
    # No default: argparse yields None; per the help text the consumer is
    # expected to substitute the dataset length.
    p.add_argument(
        "--stop_idx",
        type=int,
        help=(
            "The index of the last example to infer on. This is used if you "
            "want to evaluate on a (contiguous) subset of the data. Defaults "
            "to the length of dataset."
        ),
    )
    p.add_argument("--verbose", action="store_true")
    p.add_argument("--use_sparq", action="store_true")
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--max_seq_length", type=int, default=100000)
    p.add_argument("--rewrite", action="store_true")
    p.add_argument("--topk", type=int, default=-1)
    p.add_argument("--starting_layer", type=int, default=-1)
    p.add_argument("--start_example_id", type=int, default=0)
    p.add_argument("--topk_dims_file_path", type=str, default=None)
    p.add_argument("--kv_cache_cpu", action="store_true")
    p.add_argument("--kv_cache_cpu_device", type=str, default="cpu")
    p.add_argument(
        "--kv_type",
        type=str,
        default="dense",
        choices=[
            "dense",
            "snapkv",
            "pyramidkv",
            "quest",
            "streamingllm",
            "retr_attn",
            "kivi",
        ],
    )
    p.add_argument("--trust_remote_code", action="store_true")
    p.add_argument("--use_chat_template", action="store_true")
    p.add_argument("--same_context_different_query", action="store_true")
    p.add_argument("--tensor_parallel_size", type=int, default=1)
    p.add_argument("--max_turns", type=int, default=5)
    p.add_argument("--use_llmlingua", action="store_true")
    p.add_argument("--disable_golden_context", action="store_true")
    p.add_argument("--use_v2_data", action="store_true")
    p.add_argument(
        "--attn_type",
        type=str,
        choices=[
            "vllm",
            "vllm_minference",
            "vllm_a_shape",
            "vllm_tri_shape",
            "hf",
            "a_shape",
            "tri_shape",
            "inf_llm",
            "flash_attn",
            "minference",
            "minference_with_dense",
            "minference_with_dense_sink",
            "dilated1",
            "dilated2",
            "retrieval_attn",
            "minference_with_retr_attn",
            "vllm_kv",
            "dense",
        ],
        default="hf",
    )
    p.add_argument("--is_search", action="store_true")
    # The value is decoded with json.loads. The default is the *string* "{}"
    # rather than a dict literal: argparse runs string defaults through
    # `type`, so every parse gets its own fresh dict. A literal `{}` would be
    # one mutable object shared by every returned Namespace.
    p.add_argument("--hyper_param", type=json.loads, default="{}")
    return p.parse_args(argv)
