import ray
import asyncio
from contextlib import asynccontextmanager
import fastapi
import uvicorn
import os
from starlette.requests import Request
from starlette.responses import JSONResponse


def _get_free_port():
    import socket

    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


class BatchRequest:
    """A single queued call: its parsed input plus a future for its result.

    The future is created on the loop returned by ``asyncio.get_event_loop()``
    at construction time. The batch dispatcher later resolves it with this
    input's result or with the batch's exception.
    """

    def __init__(self, input_obj):
        loop = asyncio.get_event_loop()
        self.input_obj = input_obj
        self.future = loop.create_future()


class ToolEngine:
    """Serve a tool over HTTP and execute incoming calls in batches.

    On construction the engine schedules a FastAPI/uvicorn server on a free
    port. That server accepts POST requests at ``tool_post_url``. Each request
    body is parsed with :meth:`parse_request` and queued. A background
    dispatch loop wakes every ``max_wait_time`` seconds. On each wake-up it
    hands up to ``batch_size`` queued inputs to :meth:`process_batch` and
    answers each request with :meth:`format_response` applied to its result.

    Subclasses must implement :meth:`process_batch`. They may override
    :meth:`parse_request` / :meth:`format_response` to change the JSON shape.

    The constructor calls ``asyncio.create_task`` and therefore must run while
    an event loop is running (presumably inside an async Ray actor — confirm
    against callers).
    """

    def __init__(self, tool_post_url: str, batch_size=8, max_wait_time=0.05):
        """
        Args:
            tool_post_url: Route path the POST endpoint is registered on.
            batch_size: Maximum number of queued requests passed to one
                ``process_batch`` call. Must be >= 1.
            max_wait_time: Seconds the dispatch loop sleeps between polls of
                the queue.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # With a non-positive batch size the slicing in the dispatch loop
            # leaves requests in the queue forever, so they are never answered.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # IP of the current Ray node; combined with the port by get_server_address().
        self.address = ray._private.services.get_node_ip_address()
        self.port = None  # assigned in _start_fastapi_server()
        self.tool_post_url = tool_post_url
        self.queue = []  # pending BatchRequest objects, FIFO
        self.batch_size = batch_size
        self.max_wait_time = max_wait_time
        # Set by the FastAPI lifespan startup hook; awaited by get_server_address().
        self.server_ready = asyncio.Event()
        asyncio.create_task(self._start_fastapi_server())

    async def _start_fastapi_server(self):
        """Build the FastAPI app, pick a free port and run uvicorn until exit.

        The app's lifespan startup creates ``self.lock``, sets
        ``self.running`` and starts the dispatch loop. Lifespan shutdown
        terminates the whole process with ``os._exit(-1)``.
        """

        @asynccontextmanager
        async def lifespan(app: fastapi.FastAPI):
            print("FastAPI startup")
            self.server_ready.set()

            print(
                "[BatchToolActor] Initializing lock and dispatch loop in startup event."
            )
            self.lock = asyncio.Lock()
            self.running = True
            self.dispatch_task = asyncio.create_task(self._dispatch_loop())

            yield

            # Any shutdown (per the message, e.g. a failed bind because the
            # port was taken after _get_free_port() released it) kills the
            # process outright instead of leaving a server-less engine behind.
            print(
                "FastAPI shutdown, maybe address already in use, exit process immediately."
            )
            self.running = False
            os._exit(-1)

        app = fastapi.FastAPI(lifespan=lifespan)
        app.router.add_api_route(
            self.tool_post_url, self.tool_execution, methods=["POST"]
        )

        self.port = _get_free_port()
        # NOTE(review): a list of hosts presumably relies on uvicorn forwarding
        # it to asyncio's loop.create_server (which accepts a host sequence)
        # to listen on both IPv6 and IPv4 wildcards — confirm for the pinned
        # uvicorn version.
        config = uvicorn.Config(
            app, host=["::", "0.0.0.0"], port=self.port, log_level="warning"
        )
        server = uvicorn.Server(config)
        await server.serve()

    async def tool_execution(self, request: Request):
        """POST handler: enqueue the parsed request and await its batch result."""
        req_json = await request.json()
        input_obj = self.parse_request(req_json)
        req = BatchRequest(input_obj)
        async with self.lock:
            self.queue.append(req)
        result = await req.future
        return JSONResponse(content=self.format_response(result))

    async def _dispatch_loop(self):
        """Every ``max_wait_time`` seconds, run up to ``batch_size`` queued requests.

        Requests are taken in FIFO order. The loop exits once ``self.running``
        is cleared.
        """
        while self.running:
            await asyncio.sleep(self.max_wait_time)
            async with self.lock:
                if not self.queue:
                    continue
                batch = self.queue[: self.batch_size]
                self.queue = self.queue[self.batch_size :]
            await self._run_batch(batch)

    async def _run_batch(self, batch):
        """Run ``process_batch`` on one batch and resolve each request's future.

        Ordinary exceptions are delivered to the waiting requests rather than
        raised, so a failing batch cannot stop the dispatch loop.
        """
        batch_inputs = [req.input_obj for req in batch]
        try:
            results = await self.process_batch(batch_inputs)
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(results) != len(batch):
                raise ValueError(
                    f"process_batch() returned {len(results)} results "
                    f"for a batch of {len(batch)} inputs"
                )
        except Exception as e:
            print("Batch processing failed:", str(e))
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(e)
            return
        for req, result in zip(batch, results):
            # If the handler awaiting this future was cancelled, asyncio has
            # cancelled the future as well. Setting a result on it would
            # raise InvalidStateError and stop the dispatch loop.
            if not req.future.done():
                req.future.set_result(result)

    def parse_request(self, req_json):
        """Extract the tool input from a request body (default: ``req_json["query"]``)."""
        return req_json["query"]

    def format_response(self, result):
        """Wrap a single result into the response JSON (default: ``{"result": result}``)."""
        return {"result": result}

    async def get_server_address(self) -> str:
        """Get FastAPI server address.

        Waits for the lifespan startup hook, then returns ``"<node_ip>:<port>"``.
        """
        await self.server_ready.wait()
        return f"{self.address}:{self.port}"

    def prepare_env(self) -> None:
        """No-op hook that subclasses may override (not called within this class)."""

    async def process_batch(self, batch_inputs):
        """
        Subclasses must implement this:
        Accepts a list of input objects, returns a list of results (in order).
        """
        raise NotImplementedError(
            "process_batch() must be implemented by your subclass"
        )
