

import enum
import json
from typing import List, Optional, Union
from dataclasses import dataclass, field, fields

from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils import is_hip

from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len

# Look up vLLM's quantization config for the "gptq_marlin" method.
# NOTE(review): not referenced in the visible part of this module; presumably
# kept for importers of this module — confirm before removing.
GPTQMarlinConfig = get_quantization_config("gptq_marlin")

logger = init_logger(__name__)

# Number of bytes in one gibibyte (2**30).
_GB = 2**30


class ModelConfig(ModelConfig):
    """vLLM ``ModelConfig`` built from an already-constructed HF config.

    This subclass deliberately shadows the upstream ``ModelConfig`` name it
    inherits from. Its ``__init__`` does not call ``super().__init__``.
    Instead, the caller supplies ``hf_config`` directly, and both ``model`` and
    ``tokenizer`` are set to ``hf_config._name_or_path``. Only the attributes
    assigned in ``__init__`` are populated. The inherited
    ``_verify_quantization`` and ``_verify_cuda_graph`` checks still run.
    """

    def __init__(
        self,
        hf_config: PretrainedConfig,
        dtype: str,
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
    ) -> None:
        """Populate the config from ``hf_config`` and the given options.

        Args:
            hf_config: Hugging Face model config. It is also the source of the
                model and tokenizer name (``_name_or_path``).
            dtype: Requested dtype, resolved by ``_get_and_verify_dtype``.
            seed: Random seed, stored as-is.
            max_model_len: Optional length limit, resolved against
                ``hf_config`` by ``_get_and_verify_max_len``.
            max_context_len_to_capture: Deprecated; must be ``None``.
            max_seq_len_to_capture: Stored as-is, except that falsy values
                (``None``, ``0``) are stored as ``None``.
            served_model_name: Public name(s) of the model. A list uses its
                first entry; ``None`` or an empty value falls back to
                ``self.model``.
            Remaining arguments are assigned to same-named attributes. The
            inherited verification helpers may adjust them afterwards.

        Raises:
            ValueError: If ``max_context_len_to_capture`` is not ``None``.
        """
        self.model = hf_config._name_or_path
        self.tokenizer = hf_config._name_or_path
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        if self.max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        # max_context_len_to_capture is guaranteed to be None here (it raised
        # above), so the upstream `x or max_context_len_to_capture` fallback
        # reduces to this. Falsy values such as 0 are stored as None.
        self.max_seq_len_to_capture = max_seq_len_to_capture or None
        self.max_logprobs = max_logprobs
        self.skip_tokenizer_init = skip_tokenizer_init

        # Store the served name instead of silently discarding the argument.
        # This mirrors upstream vLLM's `get_served_model_name`, so code that
        # reads `served_model_name` from this config finds the attribute.
        if not served_model_name:
            self.served_model_name = self.model
        elif isinstance(served_model_name, list):
            self.served_model_name = served_model_name[0]
        else:
            self.served_model_name = served_model_name

        self.hf_config = hf_config
        self.hf_text_config = get_hf_text_config(hf_config)

        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)

        self._verify_quantization()
        self._verify_cuda_graph()


class LoadFormat(str, enum.Enum):
    """Closed set of weight-loading format names.

    Mixes in ``str`` so that each member compares equal to its plain string
    value (for example ``LoadFormat.HF == "hf"``).
    """

    AUTO = "auto"
    MEGATRON = "megatron"
    HF = "hf"
    DTENSOR = "dtensor"
    DUMMY_HF = "dummy_hf"
    DUMMY_MEGATRON = "dummy_megatron"
    DUMMY_DTENSOR = "dummy_dtensor"


@dataclass
class LoadConfig:
    """Settings that control how model weights are loaded.

    Attributes:
        load_format: Either a ``LoadFormat`` member or its string value
            (case-insensitive), which ``__post_init__`` converts to a
            ``LoadFormat`` member. Any non-string object (e.g. a model-loader
            instance) is kept unchanged.
        download_dir: Optional download directory. Stored as-is and not used
            by this class.
        model_loader_extra_config: Extra loader options, given as a dict or as
            a JSON-encoded string that ``__post_init__`` decodes. ``None`` or an
            empty value is normalized to ``{}``.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)

    def __post_init__(self):
        # Always write the normalized value back. Previously only the JSON-string
        # case was assigned, so None / "" survived instead of becoming {}.
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            model_loader_extra_config = json.loads(model_loader_extra_config)
        self.model_loader_extra_config = model_loader_extra_config
        self._verify_load_format()

    def _verify_load_format(self) -> None:
        """Convert a string ``load_format`` into a ``LoadFormat`` member.

        Non-string values are left untouched.

        Raises:
            ValueError: If the string is not a known load format, or if the
                format is unsupported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        try:
            self.load_format = LoadFormat(load_format)
        except ValueError:
            supported_load_formats = [f.value for f in LoadFormat]
            raise ValueError(f"Unknown load format '{load_format}'. "
                             f"Supported load formats are "
                             f"{supported_load_formats}") from None

        rocm_not_supported_load_format: List[str] = []
        # Check the (currently empty) deny-list before probing the platform.
        if load_format in rocm_not_supported_load_format and is_hip():
            # Compare and report enum *values* (what callers pass in), not
            # member names such as 'AUTO'.
            rocm_supported_load_format = [
                f.value for f in LoadFormat if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
                             f"Supported load formats are "
                             f"{rocm_supported_load_format}")
