import asyncio
import logging
import os
import socket
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, Optional
import fastapi
import ray
import uvicorn
from starlette.requests import Request
from starlette.responses import JSONResponse
# Module-level logger for this file.
# NOTE(review): keyed by __file__ (a filesystem path) rather than the conventional
# __name__, and it is not referenced by the code visible here (status messages below
# use print) — confirm whether it is used elsewhere before changing the name.
logger = logging.getLogger(__file__)
def _get_free_port():
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
class AsyncServerBase(ABC):
    """Base class for an async rollout server exposing an OpenAI-style HTTP endpoint.

    Construction schedules a FastAPI app served by uvicorn on the running event loop.
    The app routes ``POST /v1/chat/completions`` to :meth:`chat_completion`.
    Subclasses implement the backend-specific abstract hooks.

    NOTE(review): ``__init__`` calls ``asyncio.create_task``, so instances must be
    created while an event loop is running (presumably inside an async Ray actor).
    Confirm against callers.
    """

    def __init__(self):
        self.address = ray.util.get_node_ip_address()
        self.port = None  # assigned just before uvicorn starts serving
        self.server_ready = asyncio.Event()
        # Keep a strong reference to the server task. The event loop only holds weak
        # references to tasks, so an unreferenced task can be garbage-collected
        # mid-execution.
        self._server_task = asyncio.create_task(self._start_fastapi_server())

    async def _start_fastapi_server(self):
        """Build the FastAPI app and serve it with uvicorn on a free port until shutdown."""

        @asynccontextmanager
        async def lifespan(app: fastapi.FastAPI):
            # Startup: announce the endpoint and unblock get_server_address().
            # NOTE(review): uvicorn runs lifespan startup before binding its listening
            # sockets, so readiness may be signalled slightly before the port accepts
            # connections. Confirm for the uvicorn version in use.
            print(f"FastAPI listen on {self.address}:{self.port}")
            self.server_ready.set()
            yield
            # Shutdown: the server is not expected to stop on its own, so treat any
            # shutdown (e.g. a failed bind) as fatal and kill the whole process.
            print("FastAPI shutdown, maybe address already in use, exit process immediately.")
            os._exit(-1)

        app = fastapi.FastAPI(lifespan=lifespan)
        app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"])

        # The port is probed and then released, so another process could take it
        # before uvicorn binds. That failure mode ends in the os._exit above.
        self.port = _get_free_port()
        # Listen on both the IPv6 and IPv4 wildcard addresses.
        config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning")
        server = uvicorn.Server(config)
        await server.serve()

    async def get_server_address(self) -> str:
        """Wait until the HTTP server has started, then return its ``host:port`` string.

        IPv6 literals are wrapped in brackets (``[addr]:port``) so that the result can
        be embedded directly in a URL. IPv4 addresses are returned as ``addr:port``.
        """
        await self.server_ready.wait()
        host = self.address
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"{host}:{self.port}"

    @abstractmethod
    async def chat_completion(self, raw_request: Request) -> JSONResponse:
        """Handle a ``POST /v1/chat/completions`` request (route registered by the server)."""
        raise NotImplementedError

    @abstractmethod
    async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:
        """Generate token ids for ``prompt_ids`` using backend-specific ``sampling_params``.

        NOTE(review): presumably returns only the newly generated token ids, and
        ``request_id`` is used by the engine to track the request. Confirm in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    async def init_engine(self):
        """Backend-specific hook to initialize the inference engine."""
        raise NotImplementedError

    @abstractmethod
    async def wake_up(self):
        """Backend-specific hook; presumably restores engine resources released by ``sleep``."""
        raise NotImplementedError

    @abstractmethod
    async def sleep(self):
        """Backend-specific hook; presumably releases/offloads engine resources while idle."""
        raise NotImplementedError
def async_server_class(
    rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None
) -> type[AsyncServerBase]:
    """Resolve the server class to use for a rollout backend.

    Two modes are supported:

    * Built-in: leave ``rollout_backend_module`` and ``rollout_backend_class`` unset and
      choose ``rollout_backend`` from ``"vllm"`` or ``"sglang"``.
    * Custom: pass both ``rollout_backend_module`` and ``rollout_backend_class``. The
      class is loaded with ``verl.utils.import_utils.load_extern_type``, and
      ``rollout_backend`` is not consulted.

    Backend modules are imported inside the selected branch only, so unselected
    backends are never imported.

    Raises:
        NotImplementedError: built-in mode with an unrecognized ``rollout_backend``.
        ValueError: exactly one of the two customization arguments was supplied.
    """
    custom_spec = (rollout_backend_module, rollout_backend_class)

    # Custom mode: both halves of the spec are present.
    if all(part is not None for part in custom_spec):
        from verl.utils.import_utils import load_extern_type

        return load_extern_type(rollout_backend_module, rollout_backend_class)

    # Exactly one half present: ambiguous configuration.
    if any(part is not None for part in custom_spec):
        raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization")

    # Built-in mode.
    if rollout_backend == "vllm":
        from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer

        return AsyncvLLMServer
    if rollout_backend == "sglang":
        from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSGLangServer

        return AsyncSGLangServer
    raise NotImplementedError(f"rollout backend {rollout_backend} is not supported")