# serve_qwen3vl.py
import os, io, time, base64, asyncio, argparse, re, copy
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict

import torch
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn

from transformers import AutoProcessor

try:
    from qwen_vl_utils import process_vision_info
except Exception as e:
    process_vision_info = None

try:
    import requests
except Exception:
    requests = None

from vllm import LLM, SamplingParams

# Process-wide environment defaults. setdefault() only fills a variable when it is not already
# set, so values exported by the caller take precedence.
# TOKENIZERS_PARALLELISM=false: presumably disables HF tokenizers' internal parallelism
# (and its fork warning) — verify against the tokenizers docs.
# VLLM_WORKER_MULTIPROC_METHOD=spawn: presumably selects how vLLM starts its worker processes.
# NOTE(review): these run after `from vllm import ...` above; that only works if vLLM reads the
# variable lazily at use time — confirm for the pinned vLLM version.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")



# -------------------- args (kept aligned with serve_opencua.py; extra vLLM args are optional) --------------------
def get_args():
    """
    Build the command-line parser for this server and parse ``sys.argv``.

    Option names are kept aligned with serve_opencua.py so the same launch
    scripts work. ``--device`` and ``--device-map`` are accepted for CLI
    compatibility but are not used by the vLLM backend.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, help="HF model path, e.g. Qwen/Qwen3-VL-8B-Thinking")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"])
    p.add_argument("--max-batch", type=int, default=8)
    p.add_argument("--queue-ms", type=int, default=20)
    p.add_argument("--device", type=str, default="cuda")  # kept for compatibility; vLLM uses CUDA when available
    p.add_argument("--idle-unload-s", type=int, default=0)
    p.add_argument("--offload-mode", type=str, default="none", choices=["cpu", "disk", "none"])
    p.add_argument("--preload", action="store_true")

    # multi-GPU sharding options (kept names for compatibility)
    p.add_argument("--num-gpu-per-model", type=int, default=1, help="Use N visible GPUs per model via tensor parallel")
    p.add_argument("--device-map", type=str, default="auto", choices=["auto", "balanced", "sequential"],
                   help="(ignored in vLLM) kept only for CLI compatibility")
    # argparse %-formats help strings, so a literal percent sign must be written as "%%";
    # a bare "%" here makes `--help` raise instead of printing usage.
    p.add_argument("--max-gpu-mem", type=str, default="", help='Per-GPU cap like "70GiB"; default≈90%% of total')

    # vLLM extras (safe defaults)
    p.add_argument("--max-model-len", type=int, default=0, help="<=0 uses model default")
    p.add_argument("--enforce-eager", action="store_true", help="force eager mode (sometimes helps debugging)")
    return p.parse_args()


def _parse_mem_to_bytes(s: str) -> Optional[int]:
    """
    Convert a human-readable memory size into a byte count.

    Accepts (case-insensitively, surrounding whitespace ignored) a decimal number
    followed by a unit: B; K/KB, M/MB, G/GB, T/TB (powers of 1000); or
    KiB, MiB, GiB, TiB (powers of 1024).
    Examples: 70GiB, 70000MiB, 70GB, 70G, 8000MB, 8000M.

    Returns:
        The size in bytes (truncated to int), or None for empty input or any
        string that does not match the accepted format.
    """
    text = (s or "").strip()
    if not text:
        return None

    match = re.fullmatch(r"(?i)\s*([0-9]+(?:\.[0-9]+)?)\s*([kmgt]?i?b|[kmgt])\s*", text)
    if match is None:
        return None
    number, suffix = match.group(1), match.group(2).lower()

    # SI suffixes scale by 1000, IEC ("...ib") suffixes by 1024. The regex also admits
    # a bare "ib", which is absent here and therefore rejected below.
    scale = {
        "b": 1,
        "k": 1000, "kb": 1000, "kib": 1024,
        "m": 1000 ** 2, "mb": 1000 ** 2, "mib": 1024 ** 2,
        "g": 1000 ** 3, "gb": 1000 ** 3, "gib": 1024 ** 3,
        "t": 1000 ** 4, "tb": 1000 ** 4, "tib": 1024 ** 4,
    }.get(suffix)
    if scale is None:
        return None
    return int(float(number) * scale)


def _first_not_none(d: dict, keys):
    """
    Return the first non-None value among ``d.get(k)`` for ``k`` in ``keys``.

    Keys are probed in order and probing stops at the first hit. Returns None
    when every key is missing or maps to None.
    """
    candidates = (d.get(k, None) for k in keys)
    return next((value for value in candidates if value is not None), None)


class Qwen3VLLMModel:
    """
    vLLM-backed Qwen3-VL server model with same API surface as serve_opencua.py:
    - ensure_loaded_on_gpu()
    - offload(mode)
    - generate_batch(reqs)->List[str]

    The HF processor (chat template rendering) is loaded eagerly in __init__;
    the vLLM engine is created lazily by ensure_loaded_on_gpu() and dropped by
    offload().
    """
    def __init__(
        self,
        model_path: str,
        dtype: str = "bfloat16",
        num_gpu_per_model: int = 1,
        max_gpu_mem: str = "",
        max_model_len: int = 0,
        enforce_eager: bool = False,
    ):
        """
        Args:
            model_path: HF model id or local path; used for both AutoProcessor and vLLM.
            dtype: dtype string forwarded to vLLM's ``LLM(dtype=...)``.
            num_gpu_per_model: requested tensor-parallel size (clamped to >= 1 here and
                to the number of visible CUDA devices at load time).
            max_gpu_mem: optional per-GPU cap such as "70GiB"; converted to a
                gpu_memory_utilization fraction at load time.
            max_model_len: context length passed to vLLM; <= 0 means "model default".
            enforce_eager: forwarded to vLLM's ``enforce_eager``.

        Raises:
            RuntimeError: qwen_vl_utils.process_vision_info could not be imported.
        """
        if process_vision_info is None:
            raise RuntimeError(
                "qwen_vl_utils.process_vision_info is not available. "
                "Please `pip install qwen-vl-utils` (or ensure it is importable)."
            )

        self.model_path = model_path
        self.req_dtype_str = dtype
        self.num_gpu_per_model = max(1, int(num_gpu_per_model))
        self.max_gpu_mem = (max_gpu_mem or "").strip()
        self.max_model_len = int(max_model_len)
        self.enforce_eager = bool(enforce_eager)

        self.sharded = self.num_gpu_per_model > 1

        # processor (Qwen3-VL uses processor chat_template + vision utils)
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        self.llm: Optional[LLM] = None
        self.loaded = False
        self.on_gpu = False
        # serializes engine creation (ensure_loaded_on_gpu) against teardown (offload)
        self.loading_lock = asyncio.Lock()

        torch.set_grad_enabled(False)
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

    def _visible_cuda_ids(self) -> List[int]:
        """Return the device indices [0, min(device_count, num_gpu_per_model)), or [] without CUDA."""
        if not torch.cuda.is_available():
            return []
        count = torch.cuda.device_count()
        want = max(1, self.num_gpu_per_model)
        return list(range(min(count, want)))

    def _compute_gpu_mem_utilization(self) -> float:
        """
        vLLM uses gpu_memory_utilization (fraction).
        If --max-gpu-mem is provided (e.g. 70GiB), convert to a safe fraction based on each visible GPU total memory.
        Otherwise default to 0.90 (aligned with your serve_opencua.py default behavior).

        Returns 0.0 without CUDA, 0.90 when the cap is unset or unparseable, otherwise
        the smallest per-GPU fraction clamped to [0.05, 0.99].
        """
        ids = self._visible_cuda_ids()
        if not ids:
            return 0.0

        if not self.max_gpu_mem:
            return 0.90

        cap_bytes = _parse_mem_to_bytes(self.max_gpu_mem)
        if cap_bytes is None:
            # Couldn't parse => fall back
            return 0.90

        # the same fraction applies to every TP rank, so the smallest GPU decides
        fracs = [cap_bytes / float(torch.cuda.get_device_properties(i).total_memory) for i in ids]

        util = min(fracs) if fracs else 0.90
        # clamp to sane range
        return max(0.05, min(util, 0.99))

    async def ensure_loaded_on_gpu(self):
        """Create the vLLM engine if it is not loaded yet; a no-op when already loaded."""
        async with self.loading_lock:
            if self.llm is not None:
                return

            tp = max(1, min(self.num_gpu_per_model, torch.cuda.device_count() if torch.cuda.is_available() else 1))
            util = self._compute_gpu_mem_utilization()

            print(
                f"[serve] visible CUDA ids = {self._visible_cuda_ids()}, "
                f"tensor_parallel_size={tp}, gpu_memory_utilization={util:.3f}, "
                f"max_model_len={self.max_model_len if self.max_model_len > 0 else None}, "
                f"dtype={self.req_dtype_str}",
                flush=True,
            )

            # NOTE: vLLM loads onto GPU(s) directly when CUDA available.
            # NOTE(review): this constructor runs synchronously on the event-loop thread and blocks
            # it for the whole engine start-up; it is kept there on the assumption that vLLM
            # start-up should happen on the main thread — confirm before moving it to an executor.
            self.llm = LLM(
                model=self.model_path,
                trust_remote_code=True,
                dtype=self.req_dtype_str,
                tensor_parallel_size=tp,
                gpu_memory_utilization=util if torch.cuda.is_available() else 0.0,
                max_model_len=self.max_model_len if self.max_model_len > 0 else None,
                enforce_eager=self.enforce_eager,
            )

            self.loaded = True
            self.on_gpu = torch.cuda.is_available()
            print(f"[serve] ensure_loaded_on_gpu -> loaded={self.loaded}, on_gpu={self.on_gpu}, sharded={self.sharded}", flush=True)

    async def offload(self, mode: str):
        """
        vLLM can't "move to cpu" like HF; best-effort:
        - cpu: treated as disk (delete engine)
        - disk: delete engine
        Any other mode (e.g. "none") is a no-op, as is calling this when nothing is loaded.
        """
        import gc

        async with self.loading_lock:
            if self.llm is None:
                return
            if mode not in ("cpu", "disk"):
                return
            if mode == "cpu":
                print("[serve] offload_mode=cpu not supported by vLLM; unloading engine instead.", flush=True)
            try:
                del self.llm
            finally:
                self.llm = None
                self.loaded = False
                self.on_gpu = False
                # Collect reference cycles inside the dropped engine first, so the CUDA
                # allocations they pin are actually freed before the cache is emptied.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    # -------- image decode --------
    def _decode_image(self, src: str) -> Image.Image:
        """
        Decode one image reference into an RGB PIL image.

        Resolution order: an existing local path, then an http(s) URL, then base64
        (a ``data:...,<b64>`` URL — everything after the first comma — or bare base64).

        NOTE(security): ``src`` comes from the request body, so any client can make the
        server open local files it can read, or fetch arbitrary URLs (SSRF). Restrict
        this before exposing the endpoint to untrusted callers.

        Raises:
            RuntimeError: an http(s) URL was given but ``requests`` is unavailable.
            Errors from ``raise_for_status``, base64 decoding or PIL on bad input.
        """
        s = (src or "").strip()
        # local path
        if os.path.exists(s):
            # `with` closes the underlying file deterministically; convert() returns
            # an independent, fully loaded image.
            with Image.open(s) as im:
                return im.convert("RGB")

        # http(s)
        if re.match(r"^https?://", s):
            if requests is None:
                raise RuntimeError("requests not available to fetch http(s) image.")
            resp = requests.get(s, timeout=15)
            resp.raise_for_status()
            with Image.open(io.BytesIO(resp.content)) as im:
                return im.convert("RGB")

        # data url / base64
        b64 = s.split(",", 1)[1] if "," in s else s
        raw = base64.b64decode(b64, validate=False)
        with Image.open(io.BytesIO(raw)) as im:
            return im.convert("RGB")

    def _resolve_images_in_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert any image/image_url chunks into PIL.Image so process_vision_info can consume them robustly.
        Supports:
        - local file path
        - http(s) url
        - data:... base64 url or raw base64
        ``image_url`` may be the OpenAI form ``{"url": ...}`` or a plain string.
        The input is deep-copied; the caller's messages are not modified.
        """
        msgs = copy.deepcopy(messages)
        for m in msgs:
            content = m.get("content")
            if not isinstance(content, list):
                continue
            new_content = []
            for part in content:
                if not isinstance(part, dict):
                    new_content.append(part)
                    continue
                t = part.get("type")
                if t == "image":
                    src = part.get("image")
                    if isinstance(src, Image.Image):
                        new_content.append(part)
                        continue
                    if isinstance(src, str) and src.strip():
                        img = self._decode_image(src)
                        new_content.append({"type": "image", "image": img})
                    else:
                        new_content.append(part)
                elif t == "image_url":
                    image_url = part.get("image_url")
                    # OpenAI clients send {"url": ...}; some send the URL string directly
                    url = image_url.get("url") if isinstance(image_url, dict) else image_url
                    if isinstance(url, Image.Image):
                        new_content.append({"type": "image", "image": url})
                        continue
                    if isinstance(url, str) and url.strip():
                        img = self._decode_image(url)
                        # normalize into "image" type for Qwen templates
                        new_content.append({"type": "image", "image": img})
                    else:
                        new_content.append(part)
                else:
                    new_content.append(part)
            m["content"] = new_content
        return msgs

    def _prepare_inputs_for_vllm(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        vLLM multimodal input dict (aligned with your reference code):
          {"prompt": text, "multi_modal_data": {"image": ...}, "mm_processor_kwargs": ...}
        """
        messages = self._resolve_images_in_messages(messages)
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # patch size of the model's image processor; 14 when it cannot be determined
        image_processor = getattr(self.processor, "image_processor", None)
        patch_size = getattr(image_processor, "patch_size", None)
        if isinstance(patch_size, (tuple, list)):
            patch_size = patch_size[0] if patch_size else None
        patch_size = int(patch_size) if patch_size else 14

        try:
            image_inputs, video_inputs, video_kwargs = process_vision_info(
                messages,
                image_patch_size=patch_size,
                return_video_kwargs=True,
                return_video_metadata=True,
            )
        except TypeError:
            # presumably older qwen_vl_utils without `return_video_metadata` — retry without it
            image_inputs, video_inputs, video_kwargs = process_vision_info(
                messages,
                image_patch_size=patch_size,
                return_video_kwargs=True,
            )

        mm_data: Dict[str, Any] = {}
        if image_inputs is not None:
            mm_data["image"] = image_inputs
        if video_inputs is not None:
            mm_data["video"] = video_inputs

        return {"prompt": text, "multi_modal_data": mm_data, "mm_processor_kwargs": (video_kwargs or {})}

    @staticmethod
    def _sampling_value(req: Dict[str, Any], key: str, default: Any, cast: Any) -> Any:
        """Return ``cast(req[key])``, using ``default`` when the key is missing or JSON null."""
        value = req.get(key)
        return cast(default if value is None else value)

    @staticmethod
    def _normalize_stop(stop: Any) -> Tuple[str, ...]:
        """Turn an OpenAI-style ``stop`` (string or list of strings) into a tuple of non-empty strings."""
        if isinstance(stop, str):
            return (stop,) if stop.strip() else ()
        if isinstance(stop, (list, tuple)):
            # drop None/empty entries instead of stopping on the literal text "None"
            return tuple(str(x) for x in stop if x is not None and str(x))
        return ()

    def generate_batch(self, reqs: List[Dict[str, Any]]) -> List[str]:
        """
        Returns one text per request, in the same order.
        Supports per-request params: max_tokens, temperature, top_p, stop.
        Missing or null params fall back to max_tokens=512, temperature=0.0, top_p=0.9.

        Raises:
            RuntimeError: the engine is not loaded.
            Any error from image decoding / preprocessing of any request (the whole call fails).
        """
        if self.llm is None:
            raise RuntimeError("LLM engine not loaded")
        outs: List[str] = [""] * len(reqs)

        # group by sampling params so microbatching is meaningful (but still preserves per-request variability)
        groups: Dict[Tuple[int, float, float, Tuple[str, ...]], List[Tuple[int, Dict[str, Any]]]] = defaultdict(list)

        for i, r in enumerate(reqs):
            prepared = self._prepare_inputs_for_vllm(r["messages"])
            key = (
                self._sampling_value(r, "max_tokens", 512, int),
                self._sampling_value(r, "temperature", 0.0, float),
                self._sampling_value(r, "top_p", 0.9, float),
                self._normalize_stop(r.get("stop")),
            )
            groups[key].append((i, prepared))

        for (max_tokens, temperature, top_p, stop_key), items in groups.items():
            sampling = SamplingParams(
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                stop=list(stop_key) if stop_key else None,
            )
            batch_inputs = [prep for (_idx, prep) in items]
            results = self.llm.generate(batch_inputs, sampling_params=sampling)

            # vLLM preserves order vs input list
            for (orig_i, _), out in zip(items, results):
                txt = out.outputs[0].text if (out.outputs and out.outputs[0] is not None) else ""
                outs[orig_i] = txt

        return outs


@dataclass
class PendingItem:
    """One queued request: the payload for ``generate_batch`` and the future its caller awaits."""
    payload: Dict[str, Any]
    fut: asyncio.Future


class MicroBatcher:
    """
    Groups concurrent requests into micro-batches and runs them through the model.

    ``loop()`` must run as a background task. It waits for one request, keeps draining
    the queue for up to ``queue_ms`` milliseconds or until ``max_batch`` requests are
    gathered, makes sure the model is loaded, then runs the synchronous
    ``model.generate_batch`` on a dedicated worker thread so the event loop keeps
    serving other endpoints during generation.

    ``busy`` is True from the moment a batch is formed until all of its futures are
    resolved; the idle offloader checks it before unloading the model.
    """

    def __init__(self, model: "Qwen3VLLMModel", max_batch: int, queue_ms: int):
        from concurrent.futures import ThreadPoolExecutor

        self.model = model
        self.max_batch = max_batch
        self.queue_ms = queue_ms / 1000.0  # collection window, in seconds
        self.q: asyncio.Queue[PendingItem] = asyncio.Queue()
        self.busy = False
        # A single worker: generate calls stay serialized (as with the previous inline call)
        # and always run on the same thread.
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="microbatch-generate")

    async def submit(self, payload: Dict[str, Any]) -> str:
        """Enqueue one request and wait for its generated text (or its exception)."""
        fut = asyncio.get_running_loop().create_future()
        await self.q.put(PendingItem(payload, fut))
        return await fut

    async def _collect_batch(self) -> List[PendingItem]:
        """Wait for one request, then gather more until ``max_batch`` or the window closes."""
        batch = [await self.q.get()]
        deadline = time.monotonic() + self.queue_ms
        while len(batch) < self.max_batch and time.monotonic() < deadline:
            try:
                batch.append(self.q.get_nowait())
            except asyncio.QueueEmpty:
                await asyncio.sleep(0.001)
        return batch

    @staticmethod
    def _fail(batch: List[PendingItem], exc: BaseException) -> None:
        """Propagate ``exc`` to every still-pending future in ``batch``."""
        for itm in batch:
            if not itm.fut.done():
                itm.fut.set_exception(exc)

    async def _run_generate(self, batch: List[PendingItem]) -> None:
        """Run ``generate_batch`` off the event loop and resolve each item's future."""
        loop = asyncio.get_running_loop()
        try:
            outputs = await loop.run_in_executor(
                self._executor, self.model.generate_batch, [x.payload for x in batch]
            )
            if len(outputs) != len(batch):
                raise RuntimeError(
                    f"generate_batch returned {len(outputs)} outputs for {len(batch)} requests"
                )
        except Exception as exc:
            if len(batch) == 1:
                self._fail(batch, exc)
                return
            # generate_batch fails as a whole when any single request is bad (e.g. an
            # undecodable image). Retry one by one so only the offending requests fail.
            for itm in batch:
                if not itm.fut.done():  # skip callers that already went away
                    await self._run_generate([itm])
            return

        for ans, itm in zip(outputs, batch):
            if not itm.fut.done():
                itm.fut.set_result(ans)

    async def loop(self):
        """Batching loop; runs until its task is cancelled."""
        while True:
            batch = await self._collect_batch()
            self.busy = True
            try:
                try:
                    await self.model.ensure_loaded_on_gpu()
                except Exception as exc:
                    # A load failure must not escape: that would end this task and leave
                    # every current and future request waiting forever.
                    print(f"[serve] model load failed: {exc!r}", flush=True)
                    self._fail(batch, exc)
                    continue
                await self._run_generate(batch)
            finally:
                self.busy = False

def create_app(args):
    """
    Build the FastAPI app: model wrapper, micro-batcher and HTTP routes.

    Routes:
      GET  /ready                -> {"loaded", "on_gpu", "sharded"} model state
      GET  /health               -> {"ok": True}
      POST /v1/chat/completions  -> OpenAI-like, non-streaming chat completion

    On startup it launches the batcher loop, optionally preloads the model
    (``args.preload``), and launches an idle offloader that calls
    ``mdl.offload(args.offload_mode)`` once no request has been seen for
    ``args.idle_unload_s`` seconds (disabled when that is <= 0 or the mode is "none").

    Args:
        args: namespace produced by get_args().
    """
    app = FastAPI()
    mdl = Qwen3VLLMModel(
        model_path=args.model,
        dtype=args.dtype,
        num_gpu_per_model=args.num_gpu_per_model,
        max_gpu_mem=args.max_gpu_mem,
        max_model_len=args.max_model_len,
        enforce_eager=args.enforce_eager,
    )
    batcher = MicroBatcher(mdl, max_batch=args.max_batch, queue_ms=args.queue_ms)

    class State:
        # mutable app-wide state shared by the handlers below (class used as a namespace)
        last_request_ts = time.time()
        idle_unload_s = args.idle_unload_s
        offload_mode = args.offload_mode

    # asyncio keeps only weak references to tasks; hold strong ones for the app's lifetime
    background_tasks: List[asyncio.Task] = []

    @app.get("/ready")
    async def ready():
        return {"loaded": mdl.loaded, "on_gpu": mdl.on_gpu, "sharded": mdl.sharded}

    @app.on_event("startup")
    async def _startup():
        background_tasks.append(asyncio.create_task(batcher.loop()))
        if args.preload:
            await mdl.ensure_loaded_on_gpu()
            print("[serve] preload done", flush=True)

        async def idle_offloader():
            while True:
                await asyncio.sleep(1.0)
                if State.idle_unload_s <= 0 or State.offload_mode == "none":
                    continue
                idle_for = time.time() - State.last_request_ts
                if idle_for >= State.idle_unload_s and batcher.q.empty() and not batcher.busy:
                    try:
                        await mdl.offload(State.offload_mode)
                    except Exception as e:
                        # keep the offloader alive; an escaped exception would end it silently
                        print(f"[serve] idle offload failed: {e!r}", flush=True)

        background_tasks.append(asyncio.create_task(idle_offloader()))

    @app.get("/health")
    async def health():
        return {"ok": True}

    @app.post("/v1/chat/completions")
    async def chat_completions(req: Request):
        # Client errors (unparseable body, missing messages) are 400; failures while
        # generating are reported as 500.
        try:
            body = await req.json()
        except Exception as e:
            return JSONResponse({"error": f"invalid JSON body: {e}"}, status_code=400)
        if not isinstance(body, dict) or "messages" not in body:
            return JSONResponse(
                {"error": "request body must be a JSON object with a 'messages' field"},
                status_code=400,
            )
        try:
            payload = {
                "messages": body["messages"],
                "max_tokens": body.get("max_tokens", 512),
                "temperature": body.get("temperature", 0.0),
                "top_p": body.get("top_p", 0.9),
            }
            # Optional: support OpenAI-like "stop"
            if "stop" in body:
                payload["stop"] = body["stop"]

            State.last_request_ts = time.time()
            try:
                text = await batcher.submit(payload)
            finally:
                # restart the idle clock on completion too, so a long generation is not
                # followed immediately by an unload
                State.last_request_ts = time.time()
            return JSONResponse({
                "choices": [{
                    "index": 0,
                    "message": {"role": "assistant", "content": text},
                    # NOTE: always "stop" — generate_batch returns only text, so length-truncated
                    # outputs cannot be told apart here.
                    "finish_reason": "stop"
                }]
            })
        except Exception as e:
            return JSONResponse({"error": str(e)}, status_code=500)

    return app


if __name__ == "__main__":
    # Script entry point: parse the CLI, build the app, and serve it (blocks until shutdown).
    cli_args = get_args()
    application = create_app(cli_args)
    uvicorn.run(application, host=cli_args.host, port=cli_args.port, log_level="info")
