# SPDX-License-Identifier: Apache-2.0

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# Model served for every test in this module. Any model with a chat template
# should work here; the chat-completions test below relies on one.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def server():
    """Start one server for ``MODEL_NAME`` and share it across the module.

    The server is launched with chunked prefill enabled. The fixture yields
    the ``RemoteOpenAIServer`` handle and leaves its context manager when
    pytest finalizes the module-scoped fixture.
    """
    cli_args = [
        # half precision keeps the CI run fast and light on memory
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # batching / chunked prefill config below; the 1000-token budget is
        # presumably smaller than the tokenized test prompts so that prefill
        # has to be split into several chunks
        "--max-num-seqs",
        "128",
        "--enable-chunked-prefill",
        "--max-num-batched-tokens",
        "1000",
        # the long prompts used here would otherwise flood the request log
        "--disable-log-requests",
    ]

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the shared ``server`` fixture.

    The fixture uses the default (per-test) scope, so the client's async
    context is entered before each test and exited after it.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client


@pytest.mark.asyncio
async def test_completion_stream_options_and_logprobs_with_long_prompts(
        client: openai.AsyncOpenAI):
    """Stream a completion for a long prompt and check usage on every chunk.

    The request asks for usage on every chunk (``include_usage`` together
    with ``continuous_usage_stats``) and for logprobs. Each chunk must carry
    self-consistent usage counters, every chunk up to and including the one
    with a ``finish_reason`` must carry text, and from that chunk onwards
    ``completion_tokens`` must equal the number of text chunks received.
    """
    # Test stream with long prompt
    prompt = "What is the capital of France?" * 400

    stream = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        logprobs=5,
    )

    tokens_received = 0
    finished = False
    async for chunk in stream:
        # Fail with a clear message rather than an AttributeError below.
        assert chunk.usage is not None, "chunk is missing usage stats"
        assert chunk.usage.prompt_tokens >= 0
        assert chunk.usage.completion_tokens >= 0
        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
                                            chunk.usage.completion_tokens)
        if not finished:
            # Chunks seen before the finish marker are counted as one
            # generated token each and must contain text.
            tokens_received += 1
            assert chunk.choices[0].text

            if chunk.choices[0].finish_reason is not None:
                finished = True

        if finished:
            # Also covers any chunk that arrives after the finish marker;
            # `choices` is deliberately not touched for those.
            assert chunk.usage.completion_tokens == tokens_received

    # Without this the test would pass vacuously on an empty stream or on a
    # stream that never reports a finish_reason.
    assert finished, "stream ended without reporting a finish_reason"


@pytest.mark.asyncio
async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
        client: openai.AsyncOpenAI):
    """Stream a chat completion for a long prompt and check per-chunk usage.

    The request asks for usage on every chunk (``include_usage`` together
    with ``continuous_usage_stats``) and for logprobs. Each chunk must carry
    self-consistent usage counters. A chunk whose delta content is the empty
    string must report zero completion tokens and no logprobs, and at most
    one such chunk is tolerated. From the chunk with a ``finish_reason``
    onwards, ``completion_tokens`` must equal the number of non-empty chunks
    received.
    """
    # Test stream with long prompt
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role": "user",
        "content": "What is the capital of France?" * 400
    }]
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        logprobs=True,
        top_logprobs=5,
    )

    tokens_received = 0
    empty_chunks_received = 0
    finished = False
    async for chunk in stream:
        # Fail with a clear message rather than an AttributeError below.
        assert chunk.usage is not None, "chunk is missing usage stats"
        assert chunk.usage.prompt_tokens >= 0
        assert chunk.usage.completion_tokens >= 0
        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
                                            chunk.usage.completion_tokens)

        if not finished:
            if chunk.choices[0].delta.content == "":
                # when no tokens have been generated yet
                assert chunk.usage.completion_tokens == 0
                assert chunk.choices[0].logprobs is None
                empty_chunks_received += 1
            else:
                tokens_received += 1

            if chunk.choices[0].finish_reason is not None:
                finished = True

        if finished:
            # Also covers any chunk that arrives after the finish marker;
            # `choices` is deliberately not touched for those.
            assert chunk.usage.completion_tokens == tokens_received

    # Without this the test would pass vacuously on an empty stream or on a
    # stream that never reports a finish_reason.
    assert finished, "stream ended without reporting a finish_reason"
    assert empty_chunks_received <= 1
