import json
import logging
import os
import random
import time
from collections.abc import Callable
from typing import Any

import litellm
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_message_tool_call import Function

from lita.core.protos import Message, Role

logger = logging.getLogger(__name__)


def call_llm_with_retries(
    completion_fn: Callable,
    max_retries: int = 6,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    **kwargs,
) -> Any:
    """
    Call ``completion_fn(**kwargs)``, retrying failures with exponential backoff.

    The first successful result is returned immediately; the callable is never
    invoked again once it has succeeded.

    Args:
        completion_fn: The callable to invoke (e.g. an SDK ``create`` method).
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so the callable is always invoked at least once.
        base_delay: Delay in seconds before the second attempt.
        max_delay: Upper bound on the backoff delay (applied before jitter).
        exponential_base: Factor by which the delay grows after each failure.
        jitter: If True, scale each delay by a random factor in [0.5, 1.0) to
            avoid many clients retrying in lockstep.
        **kwargs: Forwarded unchanged to ``completion_fn``.

    Returns:
        Whatever ``completion_fn`` returns on its first successful call.

    Raises:
        Exception: The exception raised by the final attempt once all attempts
            have failed.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            # Return on success: without this the loop would keep calling the LLM.
            return completion_fn(**kwargs)
        except Exception as e:
            if attempt == attempts - 1:
                logger.warning(f"{e}\nFailed to call LLM after {attempts} attempt(s), giving up.")
                raise

            # Exponential backoff, capped at max_delay.
            delay = min(base_delay * (exponential_base**attempt), max_delay)
            if jitter:
                delay *= 0.5 + random.random() * 0.5  # random factor in [0.5, 1.0)

            logger.warning(f"{e}\nFailed to call LLM, retry in {delay:.2f} seconds...")
            time.sleep(delay)


class OpenAIHandler:
    """Chat-completion handler built on the official ``openai`` SDK.

    An ``AzureOpenAI`` client is created when ``base_url`` contains ``"azure.com"``
    and an ``api_version`` is given; otherwise a plain ``OpenAI`` client is used
    (``base_url`` may point at any OpenAI-compatible server).
    """

    base_url: str | None = None
    api_version: str | None = None
    api_key: str | None = None
    model: str = "gpt-4.1"

    client: AzureOpenAI | OpenAI | None = None

    def __init__(
        self,
        model: str,
        base_url: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
    ):
        """Store connection settings and build the matching SDK client.

        Args:
            model: Model (or Azure deployment) name sent with every request.
            base_url: Endpoint URL; ``None`` defers to the SDK default.
            api_version: API version; only used to build an Azure client.
            api_key: API key; ``None`` defers to the SDK default.
        """
        self.model = model
        self.base_url = base_url
        self.api_version = api_version
        self.api_key = api_key

        if base_url and api_version and "azure.com" in base_url:
            self.client = AzureOpenAI(
                azure_endpoint=base_url,
                api_version=api_version,
                api_key=api_key,
            )
        else:
            self.client = OpenAI(
                api_key=api_key,
                base_url=base_url,
            )

    def call(
        self,
        messages: list[dict],
        tools: list[dict[str, Any]] | None,
        max_tokens: int,
        temperature: float,
        stream: bool,
        top_p: float,
        timeout: float,
        max_retries: int,
    ) -> Message | None:
        """Send one chat-completion request and wrap the reply in a ``Message``.

        Non-streaming requests go through ``call_llm_with_retries``. If that
        raises, the payload is logged and ``None`` is returned. Streaming
        requests (only allowed when no tools are given) print each token to
        stdout as it arrives and are not retried.

        Returns:
            An assistant ``Message``, or ``None`` if the non-streaming call failed.

        Raises:
            ValueError: If a non-streaming response contains no usable choice.
        """
        if stream and tools:
            logger.warning("Streaming is not supported for tool calls. Disabling streaming.")
            stream = False

        _config = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "timeout": timeout,
        }
        if tools:
            _config["tools"] = tools

        # Qwen3-specific request options, passed through ``extra_body``.
        # ENABLE_THINKING defaults to on; "false" or "0" (any case) turn it off.
        if "qwen3" in self.model.lower():
            thinking_flag = os.getenv("ENABLE_THINKING", "True").strip().lower()
            enable_thinking = thinking_flag not in ("false", "0")
            _config["extra_body"] = {
                "chat_template_kwargs": {"enable_thinking": enable_thinking},
                "separate_reasoning": enable_thinking,
            }

        if not stream:
            try:
                response: ChatCompletion = call_llm_with_retries(
                    self.client.chat.completions.create, max_retries, **_config
                )
            except Exception:
                print("Error occurred while calling LLM, see logger for the payload.")
                logger.warning("Error occurred while calling LLM, below is the payload:")
                # default=str: an unserializable payload item must not mask the real error.
                logger.warning(json.dumps(_config, indent=2, default=str))
                return None

            if not response.choices or not response.choices[0].message:
                raise ValueError(f"Invalid response from LLM {_config['model']}")
            _message: ChatCompletionMessage = response.choices[0].message
            return Message(role=Role.ASSISTANT, content=_message.content, tool_calls=_message.tool_calls)

        _config["stream"] = True
        response = self.client.chat.completions.create(**_config)
        chunks = []
        for chunk in response:
            # A stream may contain chunks whose ``choices`` list is empty (e.g. a
            # usage-only chunk); skip them instead of raising IndexError.
            if not chunk.choices:
                continue
            chunk_text = chunk.choices[0].delta.content or ""
            chunks.append(chunk_text)
            print(chunk_text, end="", flush=True)
        return Message(role=Role.ASSISTANT, content="".join(chunks))


# https://docs.anthropic.com/en/api/openai-sdk
# https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#chain-of-thought
class CapiHandler:
    """Handler for models served through the project's ``.capi.call_litellm`` helper.

    Failures are logged and reported as ``None`` rather than raised.
    """

    def __init__(
        self,
        model: str,
    ):
        """Remember the model name forwarded to ``call_litellm``."""
        self.model = model

    def call(
        self,
        messages: list[dict],
        tools: list[dict[str, Any]] | None,
        max_tokens: int,
        temperature: float,
        stream: bool,
        top_p: float,
        timeout: float,
        max_retries: int,
    ) -> Message | None:
        """Send one request via ``call_litellm`` and convert the reply to a ``Message``.

        The response may carry one or two choices. Content is taken from the
        first choice and tool calls from the last. This merges a reply that was
        split into a text choice plus a tool-call choice.

        Returns:
            An assistant ``Message``, or ``None`` in any of these cases:
            the call raised; the response had zero or more than two choices;
            or the reply carried neither content nor tool calls.
        """
        try:
            from .capi import call_litellm

            response = call_litellm(
                self.model,
                messages=messages,
                tools=tools,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=stream,
                top_p=top_p,
                timeout=timeout,
                max_retries=max_retries,
                retry_strategy="exponential_backoff_retry",
            )
        except Exception as e:
            logger.error(f"Error calling Anthropic LLM: {e}")
            return None

        choices = response.choices or []
        # Validate the shape *before* indexing, so an empty list is reported
        # rather than raising IndexError.
        if not choices or len(choices) > 2:
            logger.error(f"An unexpected response from Anthropic LLM: {response}")
            return None

        # Merge content and tool_calls (even though n=1).
        response_content = choices[0].message.content
        _message: ChatCompletionMessage = choices[-1].message

        tool_calls = None
        if _message.tool_calls:
            tool_calls = [
                ChatCompletionMessageToolCall(
                    id=call.id,
                    function=Function(arguments=call.function.arguments, name=call.function.name),
                    type=call.type,
                )
                for call in _message.tool_calls
            ]

        if not tool_calls and not response_content:
            logger.error(f"Invalid response from LLM {self.model}, probably due to an API error")
            return None

        # Forward optional metadata only when the response actually has it, so a
        # missing ``usage`` or ``_response_ms`` does not crash the call.
        metadata: dict[str, Any] = {
            "actual_model": response.model,
            "created": response.created,
        }
        usage = getattr(response, "usage", None)
        if usage is not None:
            metadata["usage"] = usage.dict()
        latency = getattr(response, "_response_ms", None)
        if latency is not None:
            metadata["response_ms"] = latency

        return Message(
            role=Role.ASSISTANT,
            content=response_content,
            tool_calls=tool_calls,
            **metadata,
        )


# Temporary workaround for https://github.com/BerriAI/litellm/issues/9792:
# raise litellm's cap on registered logging callbacks (a class attribute, so it
# applies process-wide as soon as this module is imported).
# NOTE(review): presumably needed because many LiteLLM handlers/routers are
# created in one process — confirm, and drop once the upstream issue is fixed.
litellm.litellm_core_utils.logging_callback_manager.LoggingCallbackManager.MAX_CALLBACKS = 10000


class LiteLLMHandler:
    """Handler that calls models through ``litellm``, directly or via a ``litellm.Router``.

    Keyword args accepted by ``__init__``:
        use_router: If True, read the deployment list for ``model`` from the JSON
            file at ``router_config`` and dispatch through ``litellm.Router``.
        router_config: Path to the router JSON file, relative to the current
            working directory unless absolute (default ``"../../.router.json"``).
            Only read when ``use_router`` is True.

    Without a router, the model is addressed as ``"{provider}/{model}"``, and the
    ``API_KEY`` / ``BASE_URL`` environment variables are forwarded on each call.
    """

    def __init__(self, model: str, provider: str, **kwargs):
        self.model = model
        self.provider = provider
        use_router: bool = kwargs.pop("use_router", False)
        router_config: str = kwargs.pop("router_config", "../../.router.json")
        if use_router:
            with open(router_config, encoding="utf-8") as f:
                model_list = json.load(f)[model]
            self.router = litellm.Router(model_list)
            self.completion_fn = self.router.completion
        else:
            self.completion_fn = litellm.completion
            # litellm selects the backend from the "<provider>/<model>" prefix.
            self.model = f"{provider}/{model}"
            self.api_key = os.getenv("API_KEY", None)
            self.base_url = os.getenv("BASE_URL", None)

    def call(
        self,
        messages: list[dict],
        tools: list[dict[str, Any]] | None,
        max_tokens: int,
        temperature: float,
        stream: bool,
        top_p: float,
        timeout: float,
        max_retries: int,
    ):
        """Send one request through ``self.completion_fn`` and convert it to a ``Message``.

        ``api_key`` / ``base_url`` are forwarded only in non-router mode; they are
        not set on the instance when a router is used.

        Returns:
            An assistant ``Message``, or ``None`` if the response has no choices or
            carries neither content nor tool calls. Exceptions raised by
            ``completion_fn`` propagate to the caller.
        """
        kwargs = {}
        if hasattr(self, "api_key"):
            kwargs["api_key"] = self.api_key
        if hasattr(self, "base_url"):
            kwargs["base_url"] = self.base_url

        response = self.completion_fn(
            model=self.model,
            messages=messages,
            tools=tools,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=stream,
            top_p=top_p,
            timeout=timeout,
            max_retries=max_retries,
            retry_strategy="exponential_backoff_retry",
            **kwargs,
        )

        # Guard before indexing so an empty response is reported, not an IndexError.
        if not response.choices:
            logger.error(f"Invalid response from LLM {self.model}, probably due to an API error")
            return None

        # Use fields in ChatCompletionMessage
        _message: ChatCompletionMessage = response.choices[0].message
        response_content = _message.content
        tool_calls = None
        if _message.tool_calls:
            tool_calls = [
                ChatCompletionMessageToolCall(
                    id=call.id,
                    function=Function(arguments=call.function.arguments, name=call.function.name),
                    type=call.type,
                )
                for call in _message.tool_calls
            ]

        if not tool_calls and not response_content:
            logger.error(f"Invalid response from LLM {self.model}, probably due to an API error")
            return None

        # Forward optional metadata only when present, so a missing ``usage`` or
        # ``_response_ms`` does not crash the call.
        metadata: dict[str, Any] = {
            "actual_model": response.model,
            "created": response.created,
        }
        usage = getattr(response, "usage", None)
        if usage is not None:
            metadata["usage"] = usage.dict()
        latency = getattr(response, "_response_ms", None)
        if latency is not None:
            metadata["response_ms"] = latency

        return Message(
            role=Role.ASSISTANT,
            content=response_content,
            tool_calls=tool_calls,
            **metadata,
        )


class LLM:
    """Facade that picks an SDK-specific handler and forwards ``call`` to it.

    ``provider`` and ``llm_sdk`` are resolved in this order: the explicit keyword
    argument, then the ``PROVIDER`` / ``LLM_SDK`` environment variable, then
    ``"openai"``. Any model whose name contains ``"claude"`` or ``"gemini"``
    always uses the ``"capi"`` SDK. Remaining ``**kwargs`` are passed to
    ``LiteLLMHandler`` only.

    Raises:
        ValueError: From ``__init__`` when the resolved SDK is not one of
            ``openai``, ``capi`` or ``litellm``.
    """

    model: str | None = None
    provider: str | None = None
    llm_sdk: str | None = None
    llm_handler: CapiHandler | OpenAIHandler | LiteLLMHandler | None = None

    def __init__(self, model: str, **kwargs):
        self.model = model
        # The pop default must be falsy (None); otherwise the environment
        # variables would never be consulted.
        self.provider = kwargs.pop("provider", None) or os.getenv("PROVIDER") or "openai"  # openai, anthropic, azure
        self.llm_sdk = kwargs.pop("llm_sdk", None) or os.getenv("LLM_SDK") or "openai"  # openai, capi, litellm

        if "claude" in model or "gemini" in model:
            self.llm_sdk = "capi"

        if self.llm_sdk == "openai":
            self.llm_handler = OpenAIHandler(
                model=model,
                base_url=os.getenv("BASE_URL", None),
                api_version=os.getenv("API_VERSION", None),
                api_key=os.getenv("API_KEY", None),
            )
        elif self.llm_sdk == "capi":
            self.llm_handler = CapiHandler(model=model)
        elif self.llm_sdk == "litellm":
            self.llm_handler = LiteLLMHandler(model=model, provider=self.provider, **kwargs)
        else:
            raise ValueError(f"Unsupported LLM SDK {self.llm_sdk}")

    def call(
        self,
        messages: list[dict],
        tools: list[dict[str, Any]] | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        stream: bool = False,
        top_p: float = 0.95,
        timeout: float = 180,
        max_retries: int = 3,
    ) -> Message | None:
        """Forward the request to the selected handler and return its result.

        Returns:
            The handler's assistant ``Message``; ``None`` when the handler reports
            a failed or empty response.
        """
        return self.llm_handler.call(
            messages=messages,
            tools=tools,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=stream,
            top_p=top_p,
            timeout=timeout,
            max_retries=max_retries,
        )
