import argparse
from dataclasses import dataclass
from typing import Dict, Literal, Optional
import json


@dataclass
class ArgsType:
    """Typed view of the command-line options produced by ``eval_args``.

    The class-level defaults below double as the argparse defaults in
    ``eval_args`` (which reads them as ``ArgsType.<field>``), so changing a
    default here changes the CLI default as well. The exceptions are
    ``model`` and ``job``, whose CLI defaults come from ``eval_args``'s own
    parameters.

    New fields must be appended at the end so that positional construction
    of existing fields keeps working.
    """

    model: str = "llama3.1_8b"
    job: str = "ppl"
    method: str = "fa2"
    stride: int = -1
    lora_r: int = 8
    checkpoint: Optional[str] = None

    count: int = -1
    batch_size: int = 1
    name: str = "dev"

    # NOTE(review): presumably attention sparsity / blocking parameters --
    # their exact semantics are defined by the consumers, not in this file.
    k: int = 512
    block_size_q: int = 64
    block_stride_q: int = 64
    block_size_k: int = 2
    block_stride_k: int = 1
    dense_queries: int = -1
    dense_layers: int = 3

    # Boolean switches: exposed as ``store_true`` flags on the CLI.
    disable_prompt: bool = False
    no_sample: bool = False

    input: Optional[str] = None
    output: Optional[str] = None

    no_quantize: bool = False

    max_tokens: int = 512

    rope_method: str = "none"

    disable_flash: bool = False
    disable_sparq: bool = False
    disable_sliding_window: bool = False

    sampling_method: str = "center"

    overwrite: bool = False

    dataset: str = "wikitext"

    endpoint: str = "http://localhost:30000/"

    seed: int = 42

    pooler_method: str = "dummy"
    # Given on the CLI as a JSON string; ``eval_args`` replaces it with the
    # decoded value. NOTE(review): the Dict[str, float] shape is not enforced.
    pooler_config: Optional[Dict[str, float]] = None

    # Options registered in the "h2o" section of the ``eval_args`` parser
    # (``--h2o-shift-q-pos``, ``--h2o-streaming``, ``--h2o-reduce-for-gqa``).
    # Declared here so the parsed namespace matches this type; defaults mirror
    # the parser's.
    h2o_shift_q_pos: bool = False
    h2o_streaming: bool = False
    h2o_reduce_for_gqa: str = "average"


def eval_args(
    default_model: str = "llama3.1_8b",
    default_job: str = "ppl",
    argv: Optional[list] = None,
) -> ArgsType:
    """Parse the evaluation command-line options.

    Option defaults are taken from the ``ArgsType`` class attributes, except
    ``--model`` / ``--job`` (taken from this function's parameters) and the
    ``--h2o-*`` options (defaults defined inline below).

    Args:
        default_model: Default value for ``--model``.
        default_job: Default value for ``--job``.
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The ``argparse.Namespace`` produced by the parser (annotated as
        ``ArgsType`` for attribute completion). If ``--pooler_config`` was
        given, its JSON string is replaced by the decoded value.

    Raises:
        SystemExit: On any invalid option, including a ``--pooler_config``
            value that is not valid JSON (reported through ``parser.error``).

    Side effects:
        Prints the parsed namespace to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=default_model)
    parser.add_argument("--job", type=str, default=default_job)
    parser.add_argument("--method", type=str, default=ArgsType.method)
    parser.add_argument("--stride", type=int, default=ArgsType.stride)
    parser.add_argument("--lora_r", type=int, default=ArgsType.lora_r)
    parser.add_argument("--checkpoint", type=str, default=ArgsType.checkpoint)
    parser.add_argument("--count", type=int, default=ArgsType.count)
    parser.add_argument("--batch_size", type=int, default=ArgsType.batch_size)
    parser.add_argument("--name", type=str, default=ArgsType.name)
    parser.add_argument("--block_size_q", type=int, default=ArgsType.block_size_q)
    parser.add_argument("--block_stride_q", type=int, default=ArgsType.block_stride_q)
    parser.add_argument("--block_size_k", type=int, default=ArgsType.block_size_k)
    parser.add_argument("--block_stride_k", type=int, default=ArgsType.block_stride_k)
    parser.add_argument("--k", type=int, default=ArgsType.k)
    parser.add_argument("--dense_layers", type=int, default=ArgsType.dense_layers)
    parser.add_argument("--dense_queries", type=int, default=ArgsType.dense_queries)
    parser.add_argument(
        "--disable_prompt", default=ArgsType.disable_prompt, action="store_true"
    )
    parser.add_argument("--no_sample", default=ArgsType.no_sample, action="store_true")
    parser.add_argument("--input", type=str, default=ArgsType.input)
    parser.add_argument("--output", type=str, default=ArgsType.output)
    parser.add_argument(
        "--no_quantize", default=ArgsType.no_quantize, action="store_true"
    )
    parser.add_argument("--max_tokens", type=int, default=ArgsType.max_tokens)
    parser.add_argument("--rope_method", type=str, default=ArgsType.rope_method)
    parser.add_argument(
        "--disable_flash", default=ArgsType.disable_flash, action="store_true"
    )
    parser.add_argument(
        "--disable_sparq", default=ArgsType.disable_sparq, action="store_true"
    )
    parser.add_argument(
        "--disable_sliding_window",
        default=ArgsType.disable_sliding_window,
        action="store_true",
    )
    parser.add_argument("--sampling_method", default=ArgsType.sampling_method, type=str)
    parser.add_argument("--overwrite", default=ArgsType.overwrite, action="store_true")
    parser.add_argument("--dataset", default=ArgsType.dataset, type=str)
    parser.add_argument("--endpoint", default=ArgsType.endpoint, type=str)

    # h2o (argparse maps the dashes to underscores: args.h2o_shift_q_pos, ...)
    parser.add_argument("--h2o-shift-q-pos", action="store_true")
    parser.add_argument("--h2o-streaming", action="store_true")
    parser.add_argument("--h2o-reduce-for-gqa", type=str, default="average")

    # Pooler defaults come from ArgsType like every other option, so the
    # dataclass stays the single source of truth.
    parser.add_argument("--pooler_method", type=str, default=ArgsType.pooler_method)
    parser.add_argument("--pooler_config", type=str, default=ArgsType.pooler_config)

    parser.add_argument("--seed", type=int, default=ArgsType.seed)
    args = parser.parse_args(argv)  # type: ArgsType

    # --pooler_config arrives as a JSON string; decode it in place. A malformed
    # value is reported like any other bad option (usage message + exit code 2)
    # rather than as a raw JSONDecodeError traceback.
    if args.pooler_config is not None:
        try:
            args.pooler_config = json.loads(args.pooler_config)
        except json.JSONDecodeError as err:
            parser.error(f"--pooler_config is not valid JSON: {err}")

    print(args)
    return args
