# SPDX-License-Identifier: Apache-2.0
"""Test that various errors are handled properly."""

import asyncio
import tempfile
import time
import uuid
from unittest.mock import Mock

import pytest

from tests.mq_llm_engine.utils import RemoteMQLLMEngine
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroupMetadata
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"


@pytest.fixture(scope="function")
def tmp_socket():
    with tempfile.TemporaryDirectory() as td:
        yield f"ipc://{td}/{uuid.uuid4()}"


def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    # Raise error during first forward pass.
    engine.engine.model_executor.execute_model = Mock(
        side_effect=RAISED_ERROR(RAISED_VALUE))

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_forward) as engine:

        client = await engine.make_client()

        # Server should be healthy after initial probe.
        await asyncio.sleep(2.0)
        await client.check_health()

        # Throws an error that should get ENGINE_DEAD_ERROR.
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=uuid.uuid4()):
                pass
        assert client.errored

        await asyncio.sleep(1.0)
        with pytest.raises(RAISED_ERROR):
            await client.check_health()
        assert client.errored

        # Shutdown.
        client.close()


def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
                                        ipc_path: str):
    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    # Raise error during first forward pass.
    engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
    with RemoteMQLLMEngine(
            engine_args=ENGINE_ARGS,
            ipc_path=tmp_socket,
            run_fn=run_with_evil_model_executor_health) as engine:

        client = await engine.make_client()
        assert client.is_running

        # Health probe should throw RAISED_ERROR.
        await asyncio.sleep(15.)

        with pytest.raises(RAISED_ERROR):
            await client.check_health()
        assert client.errored

        # Generate call should throw ENGINE_DEAD_ERROR
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=uuid.uuid4()):
                pass

        client.close()


def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    # Raise error during abort call.
    engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_abort) as engine:

        client = await engine.make_client()
        assert client.is_running

        # First check health should work.
        await client.check_health()

        # Trigger an abort on the client side.
        # This request ID does not exist, and will cause the engine to error
        await client.abort(request_id="foo")

        # Future generation requests will now fail
        # with reference to the original KeyError("foo")
        with pytest.raises(MQEngineDeadError) as execinfo:
            async for _ in client.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=10),
                    request_id=uuid.uuid4()):
                pass
        assert "KeyError" in repr(execinfo.value)
        assert client.errored

        # This should raise the original error.
        with pytest.raises(RAISED_ERROR):
            await client.check_health()

        client.close()


@pytest.mark.asyncio
async def test_batch_error(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_abort) as engine:

        client = await engine.make_client()
        assert client.is_running

        # First check health should work.
        await client.check_health()

        # Batch of requests
        async def do_generate(client):
            # min_tokens=2048 to keep busy the engine busy
            # to get enough time to get process a request
            # that will crash the engine
            params = SamplingParams(min_tokens=2048, max_tokens=2048)
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=params,
                                           request_id=uuid.uuid4()):
                pass

        tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]

        # This request will force a processing batch to raise
        # an exception and next the engine get errored
        await client.abort(request_id="foo")

        # The batch of those request failed, then they
        # should get the same exception as a MQEngineDeadError.
        errors = await asyncio.gather(*tasks, return_exceptions=True)
        for e in errors:
            assert isinstance(e, MQEngineDeadError)
            assert "KeyError" in repr(e)

        client.close()


@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()

        # Invalid request should fail, but not crash the server.
        with pytest.raises(ValueError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id="abcd-1",
                                           lora_request=LoRARequest(
                                               "invalid-lora", 1,
                                               "invalid-path")):
                pass

        # This request should be okay.
        async for _ in client.generate(prompt="Hello my name is",
                                       sampling_params=SamplingParams(),
                                       request_id="abcd-2"):
            pass

        # Shutdown.
        client.close()


@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:

        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args([])

        # When LLMEngine is loaded, it will crash.
        def mock_init():
            raise ValueError

        m.setattr(LLMEngine, "__init__", mock_init)

        start = time.perf_counter()
        async with build_async_engine_client(args):
            pass
        end = time.perf_counter()

        assert end - start < 60, (
            "Expected vLLM to gracefully shutdown in <60s "
            "if there is an error in the startup.")


@pytest.mark.asyncio
async def test_mp_cuda_init():
    # it should not crash, when cuda is initialized
    # in the API server process
    import torch
    torch.cuda.init()
    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args([])

    async with build_async_engine_client(args):
        pass


@pytest.mark.asyncio
async def test_engine_process_death(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()
        assert client.is_running

        # kill the engine process
        engine.proc.kill()

        # Generate call should fail
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=uuid.uuid4()):
                pass

        # And the health check should show the engine is dead
        with pytest.raises(RuntimeError, match="Engine process .* died"):
            await client.check_health()

        client.close()


def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
                                   ipc_path: str):
    """Simulate an exception while preparing inputs for the model.
    In the wild, this could be something like a multimodal input processor
    failing on invalid image data."""

    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    runner = engine.engine.model_executor.driver_worker.worker.model_runner

    # Raise error in the model runner when adding a sequence group.
    # See class ModelInputForGPUBuilder
    def raiser(_, seq_group_metadata: SequenceGroupMetadata):
        if seq_group_metadata.request_id.startswith("evil"):
            raise RAISED_ERROR(RAISED_VALUE)

    runner.builder.per_seq_group_compute_fns.append(raiser)

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_failed_inputs(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_input_processing) as engine:

        client = await engine.make_client()
        assert client.is_running

        # Engine should be healthy
        await client.check_health()

        async def run_failing_request():
            async for _ in client.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=10),
                    request_id="evil" + str(uuid.uuid4())):
                pass

        async def run_passing_request():
            async for _ in client.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=10),
                    request_id=str(uuid.uuid4())):
                pass

        passing_tasks = [
            asyncio.create_task(run_passing_request()) for _ in range(10)
        ]
        failing_tasks = [
            asyncio.create_task(run_failing_request()) for _ in range(10)
        ]
        await asyncio.gather(*failing_tasks, return_exceptions=True)
        await asyncio.gather(*passing_tasks)

        # All the bad inputs should have raised
        for task in failing_tasks:
            with pytest.raises(RAISED_ERROR):
                task.result()

        # But all good inputs should have still succeeded
        for task in passing_tasks:
            task.result()

        # And the engine should remain healthy
        assert not client.errored
        await client.check_health()

        client.close()
