# SPDX-License-Identifier: Apache-2.0

import contextlib
import os
from typing import Any, List, NamedTuple

import openai  # use the official client for correctness check
import pytest

from ...utils import RemoteOpenAIServer

# # any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
API_KEY = "abc-123"
ERROR_API_KEY = "abc"
ROOT_PATH = "llm"


@pytest.fixture(scope="module")
def server():
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--enforce-eager",
        "--max-model-len",
        "4080",
        "--root-path",  # use --root-path=/llm for testing
        "/" + ROOT_PATH,
        "--chat-template",
        DUMMY_CHAT_TEMPLATE,
    ]
    envs = os.environ.copy()

    envs["VLLM_API_KEY"] = API_KEY
    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
        yield remote_server


class TestCase(NamedTuple):
    model_name: str
    base_url: List[str]
    api_key: str
    expected_error: Any


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            model_name=MODEL_NAME,
            base_url=["v1"],  # http://localhost:8000/v1
            api_key=ERROR_API_KEY,
            expected_error=openai.AuthenticationError),
        TestCase(
            model_name=MODEL_NAME,
            base_url=[ROOT_PATH, "v1"],  # http://localhost:8000/llm/v1
            api_key=ERROR_API_KEY,
            expected_error=openai.AuthenticationError),
        TestCase(
            model_name=MODEL_NAME,
            base_url=["v1"],  # http://localhost:8000/v1
            api_key=API_KEY,
            expected_error=None),
        TestCase(
            model_name=MODEL_NAME,
            base_url=[ROOT_PATH, "v1"],  # http://localhost:8000/llm/v1
            api_key=API_KEY,
            expected_error=None),
    ],
)
async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer,
                                                   test_case: TestCase):
    saying: str = "Here is a common saying about apple. An apple a day, keeps"
    ctx = contextlib.nullcontext()
    if test_case.expected_error is not None:
        ctx = pytest.raises(test_case.expected_error)
    with ctx:
        client = openai.AsyncOpenAI(
            api_key=test_case.api_key,
            base_url=server.url_for(*test_case.base_url),
            max_retries=0)
        chat_completion = await client.chat.completions.create(
            model=test_case.model_name,
            messages=[{
                "role": "user",
                "content": "tell me a common saying"
            }, {
                "role": "assistant",
                "content": saying
            }],
            extra_body={
                "continue_final_message": True,
                "add_generation_prompt": False
            })

        assert chat_completion.id is not None
        assert len(chat_completion.choices) == 1
        choice = chat_completion.choices[0]
        assert choice.finish_reason == "stop"
        message = choice.message
        assert len(message.content) > 0
        assert message.role == "assistant"
