# SPDX-License-Identifier: Apache-2.0

import json
from argparse import ArgumentError, ArgumentTypeError
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Literal, Optional

import pytest

from vllm.config import PoolerConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
                                   get_type, is_not_builtin, is_type,
                                   nullable_kvs, optional_type)
from vllm.utils import FlexibleArgumentParser


@pytest.mark.parametrize(("type_func", "value", "expected", "deprecated"), [
    (int, "42", 42, False),
    (int, "None", None, False),
    (float, "3.14", 3.14, False),
    (float, "None", None, False),
    (str, "Hello World!", "Hello World!", False),
    (str, "None", None, False),
    (json.loads, '{"foo":1,"bar":2}', {
        "foo": 1,
        "bar": 2
    }, False),
    # Legacy ``key=value`` syntax: still parsed, but must warn.
    (json.loads, "foo=1,bar=2", {
        "foo": 1,
        "bar": 2
    }, True),
    (json.loads, "None", None, False),
])
def test_optional_type(type_func, value, expected, deprecated):
    """Check that ``optional_type`` maps the string "None" to ``None`` and
    otherwise converts *value* with the wrapped callable.

    Cases flagged ``deprecated`` must additionally emit a DeprecationWarning.
    """
    parse = optional_type(type_func)
    # The expectation is an explicit column rather than being inferred from
    # the input string, so adding new deprecated forms needs no code change.
    context = (pytest.warns(DeprecationWarning)
               if deprecated else nullcontext())
    with context:
        assert parse(value) == expected


@pytest.mark.parametrize(("type_hint", "target", "expected"), [
    (int, int, True),
    (int, float, False),
    (list[int], list, True),
    (list[int], tuple, False),
    (Literal[0, 1], Literal, True),
])
def test_is_type(type_hint, target, expected):
    """Check ``is_type`` on plain, generic-alias and ``Literal`` hints."""
    result = is_type(type_hint, target)
    assert result == expected


@pytest.mark.parametrize(("type_hints", "target", "expected"), [
    ({float, int}, int, True),
    ({int, tuple[int]}, int, True),
    ({int, tuple[int]}, float, False),
    ({str, Literal["x", "y"]}, Literal, True),
])
def test_contains_type(type_hints, target, expected):
    """Check whether ``contains_type`` finds *target* among a set of hints."""
    found = contains_type(type_hints, target)
    assert found == expected


@pytest.mark.parametrize(("type_hints", "target", "expected"), [
    ({int, float}, int, int),
    ({int, float}, str, None),
    ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
])
def test_get_type(type_hints, target, expected):
    """Check that ``get_type`` picks the hint matching *target*, or None."""
    selected = get_type(type_hints, target)
    assert selected == expected


@config
@dataclass
class DummyConfigClass:
    """Minimal ``@config`` dataclass used by ``test_get_kwargs`` and
    ``test_is_not_builtin``.

    Each field covers a different type-hint shape that ``get_kwargs`` must
    translate: bool, optional bool, optional ``Literal``, variadic tuple,
    fixed-length tuple and list. The string literal directly below each
    field is that field's attribute docstring; keep it adjacent to its field.
    NOTE(review): presumably ``config`` / ``get_kwargs`` read these
    docstrings from source (e.g. for CLI help) - confirm before reordering.
    """
    regular_bool: bool = True
    """Regular bool with default True"""
    optional_bool: Optional[bool] = None
    """Optional bool with default None"""
    optional_literal: Optional[Literal["x", "y"]] = None
    """Optional literal with default None"""
    tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
    """Tuple with default (1, 2, 3)"""
    tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
    """Tuple with default (1, 2)"""
    list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
    """List with default [1, 2, 3]"""


@pytest.mark.parametrize(("type_hint", "expected"), [
    (int, False),
    (DummyConfigClass, True),
])
def test_is_not_builtin(type_hint, expected):
    """``int`` counts as builtin; a locally defined dataclass does not."""
    outcome = is_not_builtin(type_hint)
    assert outcome == expected


def test_get_kwargs():
    """Check the argparse kwargs ``get_kwargs`` derives from field types."""
    # Debug print of the whole kwargs dict removed: it added noise to
    # captured output without asserting anything.
    kwargs = get_kwargs(DummyConfigClass)

    # bools should not have their type set
    assert kwargs["regular_bool"].get("type") is None
    assert kwargs["optional_bool"].get("type") is None
    # optional literals should have None as a choice
    assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
    # tuples: variadic -> nargs "+", fixed length -> exact element count
    assert kwargs["tuple_n"]["nargs"] == "+"
    assert kwargs["tuple_2"]["nargs"] == 2
    # lists: element type is used as `type`, with variadic nargs
    assert kwargs["list_n"]["type"] is int
    assert kwargs["list_n"]["nargs"] == "+"


@pytest.mark.parametrize(("arg", "expected"), [
    (None, {}),
    ("image=16", {"image": 16}),
    ("image=16,video=2", {"image": 16, "video": 2}),
    ("Image=16, Video=2", {"image": 16, "video": 2}),
])
def test_limit_mm_per_prompt_parser(arg, expected):
    """Parse the legacy ``key=value`` form of ``--limit-mm-per-prompt``.

    This functionality is deprecated and will be removed in the future;
    the argument should be passed as a JSON string instead. As the last
    case shows, keys are lower-cased and surrounding spaces are ignored.

    TODO: Remove with nullable_kvs."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    argv = [] if arg is None else ["--limit-mm-per-prompt", arg]
    parsed = parser.parse_args(argv)
    assert parsed.limit_mm_per_prompt == expected


def test_compilation_config():
    """Check every accepted spelling of ``-O`` and ``--compilation-config``.

    Covers the default (unset), the ``-O`` shorthand in joined, separated
    and ``=`` forms, and a dict literal passed either as a separate argv
    token or joined to the flag with ``=``.
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # default value
    args = parser.parse_args([])
    assert args.compilation_config is None

    # -O3, -O 3 (space) and -O=3 (equals) must all select level 3
    for argv in (["-O3"], ["-O", "3"], ["-O=3"]):
        args = parser.parse_args(argv)
        assert args.compilation_config.level == 3, argv

    # string form of a dict, passed both as a separate token and joined
    # to the flag with "="
    dict_arg = "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}"
    for argv in (["--compilation-config", dict_arg],
                 [f"--compilation-config={dict_arg}"]):
        args = parser.parse_args(argv)
        # Separate asserts so a failure reports which field was wrong.
        assert args.compilation_config.level == 3, argv
        assert (args.compilation_config.cudagraph_capture_sizes
                == [1, 2, 4, 8]), argv


def test_prefix_cache_default():
    """Prefix caching is off by default and toggled by its paired flags."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    def engine_args_for(argv):
        # Parse argv and build EngineArgs exactly as the CLI would.
        return EngineArgs.from_cli_args(args=parser.parse_args(argv))

    # no flags: defaults to off
    assert not engine_args_for([]).enable_prefix_caching, (
        "prefix caching defaults to off.")

    # explicit enable flag turns it on
    assert engine_args_for(["--enable-prefix-caching"]).enable_prefix_caching

    # explicit disable flag turns it off
    assert not engine_args_for(
        ["--no-enable-prefix-caching"]).enable_prefix_caching


def test_valid_pooling_config():
    """A JSON ``--override-pooler-config`` becomes an equal PoolerConfig."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    cli = ['--override-pooler-config', '{"pooling_type": "MEAN"}']
    engine_args = EngineArgs.from_cli_args(args=parser.parse_args(cli))
    assert engine_args.override_pooler_config == PoolerConfig(
        pooling_type="MEAN")


@pytest.mark.parametrize(
    "arg",
    [
        "image",  # Missing =
        "image=4,image=5",  # Conflicting values
        "image=video=4",  # Too many = in tokenized arg
    ],
)
def test_bad_nullable_kvs(arg):
    """Malformed ``key=value`` strings are rejected by ``nullable_kvs``."""
    with pytest.raises(ArgumentTypeError):
        nullable_kvs(arg)


# yapf: disable
@pytest.mark.parametrize(("arg", "expected", "option"), [
    (None, None, "mm-processor-kwargs"),
    ("{}", {}, "mm-processor-kwargs"),
    ('{"num_crops": 4}', {"num_crops": 4}, "mm-processor-kwargs"),
    (
        '{"foo": {"bar": "baz"}}',
        {"foo": {"bar": "baz"}},
        "mm-processor-kwargs",
    ),
    (
        '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}',
        {
            "cast_logits_dtype": "bfloat16",
            "sequence_parallel_norm": True,
            "sequence_parallel_norm_threshold": 2048,
        },
        "override-neuron-config",
    ),
])
# yapf: enable
def test_composite_arg_parser(arg, expected, option):
    """JSON-valued options parse into (nested) dicts; omitted -> None."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    argv = [] if arg is None else [f"--{option}", arg]
    # argparse stores "--foo-bar" under the attribute "foo_bar".
    dest = option.replace("-", "_")
    assert getattr(parser.parse_args(argv), dest) == expected


def test_human_readable_model_len():
    """Check ``--max-model-len`` accepts plain ints and k/m/K/M suffixes.

    Lowercase suffixes are decimal multipliers (1k == 1000) and may carry a
    fractional value whose result is reduced to an int. Uppercase suffixes
    are binary multipliers (1K == 1024) and accept whole numbers only.
    """
    # `exit_on_error` disabled to test invalid values below: argparse then
    # raises ArgumentError instead of exiting the process.
    parser = EngineArgs.add_cli_args(
        FlexibleArgumentParser(exit_on_error=False))

    def parse(value):
        # Parse a single --max-model-len value and return the result.
        return parser.parse_args(["--max-model-len", value]).max_model_len

    args = parser.parse_args([])
    assert args.max_model_len is None

    assert parse("1024") == 1024

    # Lowercase: decimal multipliers
    assert parse("1m") == 1_000_000
    assert parse("10k") == 10_000

    # Uppercase: binary multipliers
    assert parse("3K") == 1024 * 3
    assert parse("10M") == 2**20 * 10

    # Decimal values
    assert parse("10.2k") == 10200
    # ..the fractional part of the result is dropped (10212.345 -> 10212)
    assert parse("10.212345k") == 10212

    # Invalid values
    invalid_values = [
        "1a",  # unrecognized suffix
        "pwd",  # not a number
        "10.24",  # fractional value without a multiplier
        "1.23M",  # decimals are not allowed with binary multipliers
    ]
    for invalid in invalid_values:
        with pytest.raises(ArgumentError):
            parse(invalid)
