"""
vLLM backend — two modes:

  offline  (default): Loads the model in-process via vllm.LLM.
           Fastest for large-scale evaluation on a single machine.
           Requires: pip install vllm

  server:  Connects to a running vLLM OpenAI-compatible server.
           Set mode="server" and base_url="http://localhost:8000/v1".
           Requires: pip install openai  (standard OpenAI client reused)

Example offline usage:
  cfg = {"backend_type": "vllm", "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
         "tensor_parallel_size": 2}

Example server usage:
  cfg = {"backend_type": "vllm", "mode": "server",
         "base_url": "http://localhost:8000/v1",
         "model": "mistralai/Mixtral-8x7B-Instruct-v0.1"}
"""
from __future__ import annotations

from meta_rg.backends.base import BaseBackend


class VLLMBackend(BaseBackend):
    """Generate text with vLLM, either in-process or through an HTTP server.

    ``mode="offline"`` loads the model in this process via ``vllm.LLM``.
    ``mode="server"`` talks to a running vLLM OpenAI-compatible server
    through the ``openai`` client.
    """

    # Accepted values for the ``mode`` constructor argument.
    _MODES = ("offline", "server")

    def __init__(
        self,
        model: str,
        mode: str = "offline",          # "offline" | "server"
        base_url: str = "http://localhost:8000/v1",
        api_key: str = "EMPTY",
        temperature: float = 0.7,
        max_tokens: int = 256,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.90,
        system_prompt: str | None = None,
        **_,
    ) -> None:
        """Configure the backend and load the model or create the client.

        Args:
            model: Model name/path. Offline: passed to ``vllm.LLM``.
                Server: sent as the ``model`` field of each request.
            mode: ``"offline"`` or ``"server"``.
            base_url: Server endpoint (server mode only).
            api_key: API key for the server (server mode only).
            temperature: Sampling temperature for every request.
            max_tokens: Maximum number of generated tokens per request.
            tensor_parallel_size: Forwarded to ``vllm.LLM`` (offline only).
            gpu_memory_utilization: Forwarded to ``vllm.LLM`` (offline only).
            system_prompt: Optional system message prepended to each prompt.
            **_: Extra configuration keys are accepted and ignored.

        Raises:
            ValueError: If ``mode`` is not one of ``"offline"``/``"server"``.
            ImportError: If the package required by ``mode`` is missing.
        """
        # Fail fast on typos: previously any unknown mode silently fell
        # through to server mode and tried to reach base_url.
        if mode not in self._MODES:
            raise ValueError(
                f"Unknown vLLM mode {mode!r}; expected one of {self._MODES}"
            )

        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.system_prompt = system_prompt
        self._mode = mode
        self._seed: int | None = None

        if mode == "offline":
            try:
                from vllm import LLM, SamplingParams
            except ImportError as e:
                raise ImportError("pip install vllm") from e
            self._llm = LLM(
                model=model,
                tensor_parallel_size=tensor_parallel_size,
                gpu_memory_utilization=gpu_memory_utilization,
            )
            # Keep the class around so vllm is only imported in offline mode.
            self._SamplingParams = SamplingParams
            self._tokenizer = self._llm.get_tokenizer()
        else:
            try:
                from openai import OpenAI
            except ImportError as e:
                raise ImportError("pip install openai") from e
            self._client = OpenAI(base_url=base_url, api_key=api_key)

    def set_seed(self, seed: int) -> None:
        """Use *seed* for all subsequent :meth:`generate` calls.

        The seed applies in both modes. Offline, it is passed to
        ``SamplingParams``. In server mode, it is sent as the request's
        ``seed`` field, which the vLLM OpenAI-compatible server accepts.
        """
        self._seed = seed

    @property
    def tokenize_fn(self):
        """Return a ``text -> token count`` callable, or ``None``.

        Only offline mode has a local tokenizer, so server mode returns
        ``None``. NOTE(review): ``tok.encode`` may count special tokens
        (e.g. BOS) depending on the tokenizer; confirm whether callers
        expect them included.
        """
        if self._mode == "offline":
            tok = self._tokenizer
            return lambda text: len(tok.encode(text))
        return None

    def generate(self, prompt_text: str) -> str:
        """Generate a completion for *prompt_text* as a single user turn.

        The configured ``system_prompt`` (if any) is sent first. Returns the
        generated text, or ``""`` when the server returns no content.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt_text})

        # Only pass a seed when one was set, so unseeded runs behave as before.
        seed_kw = {"seed": self._seed} if self._seed is not None else {}

        if self._mode == "offline":
            # Apply the chat template manually, then run vLLM on the string.
            # NOTE(review): if the template already emits a BOS token, vLLM's
            # own tokenization of this string may add a second one; verify
            # for the models in use.
            formatted = self._tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            params = self._SamplingParams(
                temperature=self.temperature, max_tokens=self.max_tokens, **seed_kw
            )
            outputs = self._llm.generate([formatted], params)
            return outputs[0].outputs[0].text

        # Server mode: OpenAI-compatible chat completions endpoint.
        response = self._client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            **seed_kw,
        )
        return response.choices[0].message.content or ""
