from typing import Mapping, Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
    import litellm
except ImportError as e:
    litellm = e
    litellm.num_retries = 1


class LiteLLM(BaseBackend):
    """Chat backend that sends generation requests through ``litellm``.

    Every request calls ``litellm.completion`` with the configured model name,
    the client parameters captured at construction time, and the sampling
    parameters converted by ``SglSamplingParams.to_litellm_kwargs()``.
    """

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        """Create the backend.

        Args:
            model_name: Model identifier passed as ``model=`` to
                ``litellm.completion``.
            chat_template: Chat template to use. When falsy, one is looked up
                from ``model_name`` via ``get_chat_template_by_model_path``.
            api_key, organization, base_url, timeout, max_retries,
            default_headers: Forwarded unchanged as keyword arguments on
                every ``litellm.completion`` call.

        Raises:
            ImportError: if the ``litellm`` package could not be imported.
        """
        # When the package is missing, the module-level import guard stores the
        # ImportError in ``litellm``. Surface it before doing any other setup.
        if isinstance(litellm, Exception):
            raise litellm

        super().__init__()

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Extra keyword arguments splatted into every completion request.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        """Return the chat template chosen at construction time."""
        return self.chat_template

    @staticmethod
    def _build_messages(s: StreamExecutor):
        """Return the chat messages to send for executor ``s``.

        Uses the executor's accumulated chat messages when there are any.
        Otherwise its raw text is wrapped as a single user message.
        """
        if s.messages_:
            return s.messages_
        return [{"role": "user", "content": s.text_}]

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a blocking completion for ``s``.

        Returns:
            A ``(text, meta_info)`` tuple. ``meta_info`` is always an empty
            dict.
        """
        ret = litellm.completion(
            model=self.model_name,
            messages=self._build_messages(s),
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        # ``message.content`` may be None (e.g. a tool-call-only reply).
        # Normalize it to "" so callers always receive a string.
        comp = ret.choices[0].message.content or ""

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion for ``s``.

        Yields:
            ``(text_piece, meta_info)`` tuples for every non-empty content
            delta. ``meta_info`` is always an empty dict.
        """
        ret = litellm.completion(
            model=self.model_name,
            messages=self._build_messages(s),
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            # Some providers emit chunks with an empty ``choices`` list (e.g. a
            # trailing usage-only chunk). Skip them instead of raising
            # IndexError in the middle of the stream.
            if not chunk.choices:
                continue
            text = chunk.choices[0].delta.content
            if text is not None:
                yield text, {}
