#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
OpenAI 兼容接口的 LLM 实现
用于连接 vLLM OpenAI-compat 服务 (http://host:port/v1)
"""
from typing import List, Dict, Any, Optional
import logging
import requests

from .base_llm import BaseLLM

logger = logging.getLogger(__name__)


class OpenAICompatibleLLM(BaseLLM):
    """LLM client that talks to a vLLM server through its OpenAI-compatible API.

    Requests are sent to ``/v1/chat/completions`` or ``/v1/completions`` under
    the configured ``server_url`` using a shared ``requests.Session``.

    Recognised config keys:
        server_url: Base URL of the server, with or without a trailing ``/v1``
            (required).
        served_model_name: Model name sent in the ``model`` field (required).
        api_key: Optional bearer token added to the ``Authorization`` header.
        timeout: Per-request timeout in seconds (default 120).

    NOTE(review): ``self.max_tokens``, ``self.temperature`` and ``self.top_p``
    are read as sampling defaults but are not set in this class; presumably
    ``BaseLLM.__init__`` sets them from ``config`` -- confirm.
    """

    def __init__(self, config: Dict[str, Any]):
        """Validate the connection settings and open the HTTP session.

        Args:
            config: Configuration mapping; see the class docstring for keys.

        Raises:
            ValueError: If ``server_url`` or ``served_model_name`` is missing
                or empty.
        """
        # Connection settings are assigned before the base constructor runs so
        # they are already available if it calls back into this subclass.
        self.server_url: str = config.get("server_url", "").rstrip("/")
        self.served_model_name: str = config.get("served_model_name", "")
        self.api_key: Optional[str] = config.get("api_key")  # optional
        self.timeout: int = config.get("timeout", 120)
        if not self.server_url or not self.served_model_name:
            raise ValueError("OpenAICompatibleLLM 需要配置 server_url 与 served_model_name")

        self.session: Optional[requests.Session] = None
        super().__init__(config)
        # BaseLLM.__init__ is not visible here and may already have invoked
        # _load_model(); only create a session if none exists yet, so an
        # earlier session is not dropped without being closed.
        if self.session is None:
            self._load_model()

    def _load_model(self):
        """Create the HTTP session; no model is loaded locally."""
        self.session = requests.Session()

    def _headers(self) -> Dict[str, str]:
        """Return the request headers, including bearer auth when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``server_url + path`` and return the decoded body.

        Args:
            path: Endpoint path such as ``/v1/chat/completions``.
            payload: JSON-serialisable request body.

        Returns:
            The decoded JSON response.

        Raises:
            requests.HTTPError: If the server answers with an error status
                (raised by ``raise_for_status``).
        """
        # Avoid a doubled "/v1" when server_url already ends with it.
        if self.server_url.endswith("/v1") and path.startswith("/v1/"):
            path = path[len("/v1"):]
        url = f"{self.server_url}{path}"
        resp = self.session.post(
            url, json=payload, headers=self._headers(), timeout=self.timeout
        )
        resp.raise_for_status()
        return resp.json()

    def _complete(
        self,
        prompt: str,
        max_tokens: Optional[int],
        temperature: Optional[float],
        top_p: Optional[float],
        repetition_penalty: Optional[float],
        logprobs: Optional[int],
        extra: Dict[str, Any],
    ):
        """Call ``/v1/completions`` with a raw prompt.

        Args:
            prompt: Text sent in the ``prompt`` field.
            max_tokens: Value for the ``max_tokens`` field.
            temperature: Value for the ``temperature`` field.
            top_p: Value for the ``top_p`` field.
            repetition_penalty: Added to the payload only when not ``None``.
            logprobs: Added to the payload only when not ``None``.
            extra: Additional fields merged into the payload last, so they
                override the fields above on key collisions.

        Returns:
            The full response dict when ``logprobs`` is not ``None``;
            otherwise the stripped text of the first choice, or ``""`` when
            the response does not have the expected shape.
        """
        payload: Dict[str, Any] = {
            "model": self.served_model_name,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if repetition_penalty is not None:
            payload["repetition_penalty"] = repetition_penalty
        if logprobs is not None:
            payload["logprobs"] = logprobs
        payload.update(extra)

        data = self._post("/v1/completions", payload)
        if logprobs is not None:
            # Callers asking for logprobs get the whole response.
            return data
        try:
            return data["choices"][0]["text"].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            logger.error(f"Unexpected response from server: {data}")
            return ""

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        logprobs: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Generate a completion for a single-turn prompt.

        Models whose served name contains ``finetune`` or ``alpaca``
        (case-insensitive) are sent straight to ``/v1/completions``. All other
        models are tried on ``/v1/chat/completions`` with the prompt wrapped
        as one user message; if that call or the parsing of its response
        fails, the request is retried on ``/v1/completions``.

        Args:
            prompt: The prompt text.
            max_tokens: Token limit; a falsy value falls back to
                ``self.max_tokens``.
            temperature: Sampling temperature; ``None`` falls back to
                ``self.temperature``.
            top_p: Nucleus sampling value; ``None`` falls back to
                ``self.top_p``.
            repetition_penalty: Forwarded only when not ``None``. It is not
                part of the standard OpenAI schema and is passed through as a
                server-specific parameter.
            logprobs: When not ``None``, log-probabilities are requested. The
                chat endpoint receives ``logprobs=True`` plus
                ``top_logprobs=<value>``; the completions endpoint receives
                the integer directly.
            **kwargs: Extra request fields; entries whose value is ``None``
                are dropped, the rest are merged into the payload.

        Returns:
            The stripped generated text, or ``""`` if a completions response
            has an unexpected shape. NOTE: when ``logprobs`` is not ``None``
            the full response dict is returned instead of a string, despite
            the ``str`` annotation.

        Raises:
            requests.RequestException: If the ``/v1/completions`` request
                itself fails (chat failures are caught and trigger the
                fallback instead).
        """
        max_tokens = max_tokens or self.max_tokens
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p
        extra = {k: v for k, v in kwargs.items() if v is not None}

        # Fine-tuned Alpaca-format models take the raw prompt directly.
        model_name = self.served_model_name.lower()
        if "finetune" in model_name or "alpaca" in model_name:
            return self._complete(
                prompt, max_tokens, temperature, top_p,
                repetition_penalty, logprobs, extra,
            )

        payload: Dict[str, Any] = {
            "model": self.served_model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if repetition_penalty is not None:
            payload["repetition_penalty"] = repetition_penalty
        if logprobs is not None:
            payload["logprobs"] = True  # the chat API takes a boolean here
            payload["top_logprobs"] = logprobs
        # Pass through any other sampling-related parameters.
        payload.update(extra)

        try:
            data = self._post("/v1/chat/completions", payload)
            if logprobs is not None:
                logger.info(f"返回完整响应包含logprobs: {type(data)}")
                return data
            return data["choices"][0]["message"]["content"].strip()
        except Exception as chat_error:
            # Deliberate best-effort: any chat failure (transport error or
            # unexpected response shape) falls back to the completions API.
            logger.warning(f"Chat completions failed: {chat_error}, trying completions API")
            return self._complete(
                prompt, max_tokens, temperature, top_p,
                repetition_penalty, logprobs, extra,
            )

    def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
        """Run ``generate`` on each prompt sequentially and collect the results.

        Args:
            prompts: Prompts to process, in order.
            **kwargs: Forwarded unchanged to every ``generate`` call.

        Returns:
            One result per prompt, in input order.
        """
        # Plain serial calls keep the interface simple; could be parallelised.
        return [self.generate(prompt, **kwargs) for prompt in prompts]

    def chat(
        self,
        messages: List[Dict[str, str]],
        chat_templates: Dict[str, Dict[str, str]],
        **kwargs,
    ) -> str:
        """Forward ``messages`` unchanged to ``/v1/chat/completions``.

        Args:
            messages: Chat messages sent as the ``messages`` field.
            chat_templates: Unused here; no template formatting is applied.
                Kept so the signature stays unchanged.
            **kwargs: Optional ``max_tokens``, ``temperature``, ``top_p`` and
                ``repetition_penalty``, plus any other request fields. A
                missing or ``None`` sampling value falls back to the instance
                default; other ``None`` entries are dropped.

        Returns:
            The stripped content of the first choice, or ``""`` when the
            response does not have the expected shape.

        Raises:
            requests.RequestException: If the HTTP request fails.
        """

        def setting(name: str, default: Any) -> Any:
            # An explicit None means "use the default", same as in generate().
            value = kwargs.get(name)
            return default if value is None else value

        payload: Dict[str, Any] = {
            "model": self.served_model_name,
            "messages": messages,
            "max_tokens": setting("max_tokens", self.max_tokens),
            "temperature": setting("temperature", self.temperature),
            "top_p": setting("top_p", self.top_p),
        }
        repetition_penalty = kwargs.get("repetition_penalty")
        if repetition_penalty is not None:
            payload["repetition_penalty"] = repetition_penalty
        # Pass through any other parameters.
        payload.update({k: v for k, v in kwargs.items() if v is not None})

        data = self._post("/v1/chat/completions", payload)
        try:
            return data["choices"][0]["message"]["content"].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            logger.error(f"Unexpected response from server: {data}")
            return ""
