from contextlib import contextmanager
from typing import Any, AsyncGenerator, Mapping, Optional

import cloudpickle
import zmq
import zmq.asyncio

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                         VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

# Time to wait (in milliseconds) for the server's readiness reply before
# giving up on one attempt; passed as the poll timeout by
# `AsyncEngineRPCClient.wait_for_server`. Presumably the caller uses the
# resulting timeout to check whether the server process is still alive
# before retrying -- confirm against the call site.
SERVER_START_TIMEOUT_MS = 1000


class AsyncEngineRPCClient:
    """Asynchronous ZeroMQ client that talks to the RPC server.

    Every request opens a fresh DEALER socket connected to ``rpc_path``,
    sends a cloudpickle-serialized request and reads cloudpickle-serialized
    replies. Configuration objects fetched from the server during
    :meth:`setup` are cached on the instance.

    NOTE(review): replies are deserialized with ``cloudpickle.loads``, which
    can execute arbitrary code. The endpoint behind ``rpc_path`` must
    therefore be trusted.
    """

    def __init__(self, rpc_path: str):
        """Create the ZeroMQ context; sockets are opened per request.

        Args:
            rpc_path: Address handed to ``socket.connect`` for each request.
        """
        self.context = zmq.asyncio.Context()
        self.rpc_path = rpc_path
        # Initialised here (and not only in ``setup``) so the synchronous
        # ``is_running`` / ``is_stopped`` / ``errored`` properties never
        # raise AttributeError when read before ``setup`` has completed.
        self._errored = False

    async def setup(self):
        """Setup the client before it starts sending server requests.

        Waits for the server's readiness reply, then fetches and caches the
        model config, decoding config and tracing flag, and builds the
        tokenizer group from the server-side configs.

        Raises:
            TimeoutError: If the server does not answer the readiness
                request within ``SERVER_START_TIMEOUT_MS``.
            ValueError: If any reply has an unexpected type or value.
        """

        # Wait until server is ready.
        await self.wait_for_server()
        self._errored = False

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()
        self.tracing_flag = await self._is_tracing_enabled_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        """Yield a DEALER socket connected to ``rpc_path``.

        The socket is always closed on exit, with ``linger=0``.
        """
        # Ensure client sockets are always closed after use

        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.rpc_path)
            yield socket
        finally:
            # linger == 0 means discard unsent messages
            # when the socket is closed. This is necessary
            # because otherwise self.context.destroy() will
            # wait for 30 seconds until unsent messages are
            # received, which is impossible if the server
            # crashed. In the absence of a server crash we
            # always expect a response before closing the
            # socket anyway.
            # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
            socket.close(linger=0)

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Args:
            request: Utility request to send to the server.
            expected_type: Type (or tuple of types) the reply must be an
                instance of.
            error_message: Message of the ``ValueError`` raised on mismatch.

        Returns:
            The deserialized reply.

        Raises:
            ValueError: If the reply is not an instance of ``expected_type``
                (a ``None`` reply is accepted when ``LoRAConfig`` is
                expected).
        """

        with self.socket() as socket:

            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

        # LoRAConfig can be None.
        lora_config_absent = data is None and expected_type is LoRAConfig
        if not lora_config_absent and not isinstance(data, expected_type):
            raise ValueError(error_message)

        return data

    async def _send_one_way_rpc_request(self,
                                        request: RPC_REQUEST_TYPE,
                                        error_message: str,
                                        timeout: Optional[int] = None):
        """Send one-way RPC request to trigger an action.

        Args:
            request: Request object to send to the server.
            error_message: Message of the ``ValueError`` raised when the
                acknowledgement is not ``VLLM_RPC_SUCCESS_STR``.
            timeout: Optional time in milliseconds to wait for the
                acknowledgement; ``None`` waits indefinitely.

        Returns:
            The acknowledgement string.

        Raises:
            TimeoutError: If ``timeout`` is given and no reply arrives in
                time.
            ValueError: If the reply is not the success string.
        """
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            if timeout is not None and await socket.poll(timeout=timeout) == 0:
                raise TimeoutError(f"server didn't reply within {timeout} ms")

            response = cloudpickle.loads(await socket.recv())

        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
            raise ValueError(error_message)

        return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the tokenizer for ``lora_request`` from the tokenizer group."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig cached by ``setup``."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig cached by ``setup``."""
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        """Return the tracing flag cached by ``setup``."""
        return self.tracing_flag

    async def wait_for_server(self):
        """Wait for the RPCServer to start up.

        Raises:
            TimeoutError: If no reply arrives within
                ``SERVER_START_TIMEOUT_MS`` milliseconds.
            ValueError: If the reply is not the success string.
        """

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.",
            timeout=SERVER_START_TIMEOUT_MS)

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self) -> Optional[LoRAConfig]:
        """Get LoRAConfig from the RPCServer (``None`` is a valid reply)."""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def _is_tracing_enabled_rpc(self) -> bool:
        """Get is_tracing_enabled flag from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled flag from RPC "
            "Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    @property
    def is_running(self) -> bool:
        """True while no failed health check has been recorded."""
        return not self._errored

    @property
    def is_stopped(self) -> bool:
        """True once a failed health check has been recorded."""
        return self._errored

    @property
    def errored(self) -> bool:
        """True once a failed health check has been recorded."""
        return self._errored

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses.

        Yields each deserialized reply until one has ``finished`` set. If
        the server replies with an exception, a health check is run first
        (recording failure in ``_errored``) and the exception is re-raised.

        If the stream ends before a finished output was received (error or
        the consumer stopped iterating), an abort for ``request_id`` is sent
        to the server, unless the server has been marked as errored.
        """

        finished = False
        try:
            with self.socket() as socket:

                # Send RPCGenerateRequest to the RPCServer.
                await socket.send_multipart([
                    cloudpickle.dumps(
                        RPCGenerateRequest(
                            inputs=inputs,
                            sampling_params=sampling_params,
                            request_id=request_id,
                            lora_request=lora_request,
                            trace_headers=trace_headers,
                            prompt_adapter_request=prompt_adapter_request))
                ])

                # Stream back the results from the RPC Server.
                while not finished:
                    message = await socket.recv()
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        # On exception, check if the server is still healthy.
                        # Use this to set the sync `is_running` and `errored`
                        # properties.
                        try:
                            await self.check_health()
                        except Exception:
                            self._errored = True
                        # NB: do before raising here so that the flag is set
                        # by the time the caller receives this exception
                        raise request_output

                    finished = request_output.finished
                    yield request_output
        finally:
            # Only notify the server when it is still believed to be healthy:
            # `abort` waits for an acknowledgement without a timeout, so
            # sending it to a server that just failed its health check could
            # block forever and delay the original exception.
            if not finished and not self._errored:
                await self.abort(request_id)

    async def check_health(self) -> None:
        """Raise if unhealthy

        Raises:
            Exception: The exception object sent back by the server, if any.
            ValueError: If the reply is not ``VLLM_RPC_HEALTHY_STR``.
        """

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
                              )

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

        if isinstance(health_message, Exception):
            raise health_message

        if health_message != VLLM_RPC_HEALTHY_STR:
            raise ValueError("Expected healthy response from backend but got "
                             f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Not supported by this client; always raises NotImplementedError."""
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")
