import asyncio
import logging
from typing import Any
import ray
from omegaconf import DictConfig
from starlette.requests import Request
from starlette.responses import JSONResponse
from verl.workers.rollout.async_server import AsyncServerBase
logger = logging.getLogger(__file__)
@ray.remote(num_cpus=1)
class AsyncSGLangServer(AsyncServerBase):
    def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str):
        super().__init__()
        self.config = config.actor_rollout_ref
        self._tp_size = self.config.rollout.get("tensor_model_parallel_size", 1)
        self._dp_size = dp_size
        self._dp_rank = dp_rank
        self.wg_prefix = wg_prefix
        self.workers = []
        self.master_worker = None
    async def init_engine(self):
        if self.workers:
            return
        all_actors = ray.util.list_named_actors(all_namespaces=True)
        matched_actors = [
            actor for actor in all_actors if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_")
        ]
        gpu_per_node = len(set([actor["name"].split(":")[1] for actor in matched_actors]))
        assert len(matched_actors) == self._dp_size * self._tp_size
        for matched_actor in matched_actors:
            fields = matched_actor["name"].split(":")
            assert len(fields) == 2, f"invalid actor name: {matched_actor['name']}"
            pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1])
            current_global_rank = gpu_per_node * pg_index + local_rank
            worker_dp_rank = current_global_rank // self._tp_size
            worker_tp_rank = current_global_rank % self._tp_size
            if worker_dp_rank == self._dp_rank:
                worker = ray.get_actor(**matched_actor)
                self.workers.append(worker)
                if worker_tp_rank == 0:
                    self.master_worker = worker
    async def chat_completion(self, raw_request: Request):
        request = await raw_request.json()
        output_future = self.master_worker.chat_completion.remote(request)
        [outputs] = await asyncio.gather(output_future)
        return JSONResponse(outputs)
    async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]:
        return await self.master_worker.generate.remote(prompt_ids, sampling_params, request_id)
    async def wake_up(self):
        if not self.config.rollout.free_cache_engine:
            return
        tasks = [worker.wake_up.remote() for worker in self.workers]
        if tasks:
            await asyncio.gather(*tasks)
    async def sleep(self):
        if not self.config.rollout.free_cache_engine:
            return
        tasks = [worker.sleep.remote() for worker in self.workers]
        if tasks:
            await asyncio.gather(*tasks)