import json
import re
import time
from typing import Any, Optional

import litellm
from litellm import completion, completion_cost
from litellm.caching.caching import Cache
from litellm.main import ModelResponse, Usage
from loguru import logger
import requests


from vita.config import (
    DEFAULT_API_URL,
    DEFAULT_HEADERS,
    DEFAULT_LLM_CACHE_TYPE,
    DEFAULT_MAX_RETRIES,
    LLM_CACHE_ENABLED,
    REDIS_CACHE_TTL,
    REDIS_CACHE_VERSION,
    REDIS_HOST,
    REDIS_PASSWORD,
    REDIS_PORT,
    REDIS_PREFIX,
    USE_LANGFUSE,
)
from vita.data_model.message import (
    AssistantMessage,
    Message,
    SystemMessage,
    ToolCall,
    ToolMessage,
    UserMessage,
)
from vita.environment.tool import Tool


# Report both successful and failed LiteLLM calls to Langfuse when enabled.
if USE_LANGFUSE:
    litellm.success_callback = ["langfuse"]
    litellm.failure_callback = ["langfuse"]

# litellm.drop_params = True

# Configure the LiteLLM response cache: disabled, Redis-backed, or in-process.
# Any other configured cache type aborts the import with a ValueError.
if not LLM_CACHE_ENABLED:
    logger.info("LiteLLM: Cache is disabled")
    litellm.disable_cache()
elif DEFAULT_LLM_CACHE_TYPE == "redis":
    logger.info(f"LiteLLM: Using Redis cache at {REDIS_HOST}:{REDIS_PORT}")
    litellm.cache = Cache(
        type=DEFAULT_LLM_CACHE_TYPE,
        host=REDIS_HOST,
        port=REDIS_PORT,
        password=REDIS_PASSWORD,
        namespace=f"{REDIS_PREFIX}:{REDIS_CACHE_VERSION}:litellm",
        ttl=REDIS_CACHE_TTL,
    )
    litellm.enable_cache()
elif DEFAULT_LLM_CACHE_TYPE == "local":
    logger.info("LiteLLM: Using local cache")
    # The local cache reuses the TTL configured for Redis.
    litellm.cache = Cache(type="local", ttl=REDIS_CACHE_TTL)
    litellm.enable_cache()
else:
    raise ValueError(
        f"Invalid cache type: {DEFAULT_LLM_CACHE_TYPE}. Should be 'redis' or 'local'"
    )


# Deprecated flag, kept until it is removed in a future version.
# Thinking mode is now selected through the enable_think parameter.
ALLOW_SONNET_THINKING = False

if not ALLOW_SONNET_THINKING:
    logger.warning(
        "Sonnet thinking is disabled (deprecated, use enable_think parameter instead)"
    )


class DictToObject:
    """
    Expose the keys of a dictionary as object attributes.

    Nested dictionaries, and dictionaries that sit directly inside a list,
    are wrapped recursively; every other value is stored unchanged.

    Usage:
    response_obj = DictToObject(response)
    print(response_obj.choices[0].message.content)  # instead of response["choices"][0]["message"]["content"]
    """

    def __init__(self, dictionary):
        def wrap(entry):
            # Only plain dict values are turned into attribute objects.
            return DictToObject(entry) if isinstance(entry, dict) else entry

        for name, raw in dictionary.items():
            if isinstance(raw, list):
                converted = [wrap(element) for element in raw]
            else:
                converted = wrap(raw)
            setattr(self, name, converted)

    def to_dict(self):
        """Convert the object (and any wrapped children) back into a dictionary."""

        def unwrap(entry):
            return entry.to_dict() if isinstance(entry, DictToObject) else entry

        result = {}
        for name, held in self.__dict__.items():
            if isinstance(held, list):
                result[name] = [unwrap(element) for element in held]
            else:
                result[name] = unwrap(held)
        return result


def _parse_ft_model_name(model: str) -> str:
    """
    Parse the ft model name from the litellm model name.
    e.g: "ft:gpt-4.1-mini-2025-04-14:sierra::BSQA2TFg" -> "gpt-4.1-mini-2025-04-14"
    """
    pattern = r"ft:(?P<model>[^:]+):(?P<provider>\w+)::(?P<id>\w+)"
    match = re.match(pattern, model)
    if match:
        return match.group("model")
    else:
        return model


def get_response_cost(response: ModelResponse) -> float:
    """
    Get the cost of the response from the litellm completion.

    The response is modified in place before pricing: its ``model`` is
    replaced by the base model name of a fine-tuned model, and the
    ``custom_llm_provider`` entry of ``_hidden_params`` is rewritten.

    Args:
        response: The completion response to price.

    Returns:
        The cost reported by ``completion_cost``, or ``0.0`` when the cost
        computation raises (the error is logged).
    """
    response.model = _parse_ft_model_name(
        response.model
    )  # FIXME: Check Litellm, passing the model to completion_cost doesn't work.

    # A response without `_hidden_params` used to crash on `.get`; fall back
    # to an empty mapping so the cost lookup is still attempted.
    hidden_params = getattr(response, "_hidden_params", None)
    if hidden_params is None:
        hidden_params = {}

    custom_llm_provider = hidden_params.get("custom_llm_provider", None)
    # NOTE(review): presumably responses tagged "openai" come from an Anthropic
    # model served through an OpenAI-compatible endpoint — confirm.
    if custom_llm_provider == "openai":
        custom_llm_provider = "anthropic"
    hidden_params["custom_llm_provider"] = custom_llm_provider
    setattr(response, "_hidden_params", hidden_params)

    try:
        return completion_cost(completion_response=response)
    except Exception as e:
        # Best effort: an unpriceable response is reported as free.
        logger.error(e)
        return 0.0


def get_response_usage(response: ModelResponse) -> Optional[dict]:
    """
    Extract the token counts from a completion response.

    Returns:
        A dict with ``completion_tokens`` and ``prompt_tokens`` read from the
        response's ``usage`` entry, or ``None`` when that entry is missing.
    """
    token_stats: Optional[Usage] = response.get("usage")
    if token_stats is not None:
        return {
            field: getattr(token_stats, field)
            for field in ("completion_tokens", "prompt_tokens")
        }
    return None


def to_vita_messages(
    messages: list[dict], ignore_roles: Optional[set[str]] = None
) -> list[Message]:
    """
    Convert a list of messages from a dictionary to a list of vita messages.

    Args:
        messages: Message dictionaries; each must have a ``"role"`` key and is
            passed as keyword arguments to the matching message class.
        ignore_roles: Roles to skip entirely. Defaults to skipping nothing.

    Returns:
        The converted messages, in input order.

    Raises:
        ValueError: If a message has a role other than ``user``,
            ``assistant``, ``tool`` or ``system`` and is not ignored.
    """
    # A None sentinel replaces the former mutable `set()` default, which was a
    # single object shared by every call.
    skipped_roles = ignore_roles if ignore_roles is not None else set()
    vita_messages = []
    for message in messages:
        role = message["role"]
        if role in skipped_roles:
            continue
        if role == "user":
            vita_messages.append(UserMessage(**message))
        elif role == "assistant":
            vita_messages.append(AssistantMessage(**message))
        elif role == "tool":
            vita_messages.append(ToolMessage(**message))
        elif role == "system":
            vita_messages.append(SystemMessage(**message))
        else:
            raise ValueError(f"Unknown message type: {role}")
    return vita_messages


def to_litellm_messages(messages: list[Message]) -> list[dict]:
    """
    Convert a list of vita messages to a list of litellm messages.

    User and system messages become ``role``/``content`` dicts. Assistant
    messages additionally carry ``tool_calls`` (``None`` when the message is
    not a tool call), and tool messages carry ``tool_call_id`` and ``name``.
    Messages of any other type are left out of the result.
    """
    converted = []
    for msg in messages:
        if isinstance(msg, UserMessage):
            entry = {"role": "user", "content": msg.content}
        elif isinstance(msg, AssistantMessage):
            calls = None
            if msg.is_tool_call():
                calls = []
                for tc in msg.tool_calls:
                    calls.append(
                        {
                            "id": tc.id,
                            "name": tc.name,
                            "function": {
                                "name": tc.name,
                                # Arguments are serialized to a JSON string.
                                "arguments": json.dumps(tc.arguments),
                            },
                            "type": "function",
                        }
                    )
            entry = {
                "role": "assistant",
                "content": msg.content,
                "tool_calls": calls,
            }
        elif isinstance(msg, ToolMessage):
            entry = {
                "role": "tool",
                "content": msg.content,
                "tool_call_id": msg.id,
                "name": msg.name,
            }
        elif isinstance(msg, SystemMessage):
            entry = {"role": "system", "content": msg.content}
        else:
            # Unrecognized message types are skipped.
            continue
        converted.append(entry)
    return converted


def _claude_think_assistant(message: Message) -> dict:
    """Build one assistant message whose content is a list of typed blocks."""
    raw_message = message.raw_data["message"]
    reasoning_content = raw_message.get("reasoning_content", None) or raw_message.get(
        "reasoning", None
    )
    if reasoning_content:
        thinking_block = {
            "type": "thinking",
            "thinking": reasoning_content,
            "signature": message.raw_data["provider_specific_fields"][
                "reasoning_details"
            ][0]["signature"],
        }
    else:
        # No reasoning recorded: an empty thinking block is sent instead.
        thinking_block = {"type": "thinking", "thinking": ""}

    # The thinking block is always kept, including for turns that only
    # contain tool calls; previously such turns lost their reasoning and
    # signature and were given an empty thinking block instead.
    content = [thinking_block]
    if message.content:
        content.append({"type": "text", "text": message.content})
    converted = {"role": "assistant", "content": content}

    tool_calls = raw_message["tool_calls"]
    if tool_calls:
        converted["tool_calls"] = [
            {
                "id": tool_call["id"],
                "type": "function",
                "function": {
                    "name": tool_call["function"]["name"],
                    "arguments": tool_call["function"]["arguments"],
                },
            }
            for tool_call in tool_calls
        ]
    return converted


def to_claude_think(litellm_messages: list[dict], messages: list[Message]) -> list[dict]:
    """
    Convert a list of litellm messages to a list of claude think messages.

    Tool messages are reduced to ``role``, ``tool_call_id`` and ``content``.
    Assistant messages are rebuilt from the raw data of the vita message at
    the same index: a thinking block, an optional text block, and the raw
    tool calls. All other messages are passed through unchanged.

    Args:
        litellm_messages: Messages in litellm format.
        messages: The vita messages; assumes ``messages[i]`` is the source of
            ``litellm_messages[i]`` — TODO confirm against callers.

    Returns:
        A new list of converted messages. If a message cannot be converted the
        error is logged and only the messages converted so far are returned.
    """
    litellm_messages_new = []
    try:
        for idx, litellm_msg in enumerate(litellm_messages):
            role = litellm_msg["role"]
            if role == "tool":
                litellm_messages_new.append(
                    {
                        "role": "tool",
                        "tool_call_id": litellm_msg["tool_call_id"],
                        "content": litellm_msg["content"],
                    }
                )
            elif role == "assistant":
                litellm_messages_new.append(_claude_think_assistant(messages[idx]))
            else:
                litellm_messages_new.append(litellm_msg)
    except Exception as e:
        # Best effort, as before, but reported through the module logger
        # instead of print so the failure shows up with the other logs.
        logger.error(e)
    return litellm_messages_new


def to_claude_think_official(litellm_messages: list[dict], messages: list[Message]) -> list[dict]:
    """
    Rewrite the content of one litellm message as a list of typed blocks.

    The target is the last message, or the second-to-last when the last one
    is a tool message. Its content becomes a text block, followed by a
    ``tool_use`` block for the first raw tool call (if any) and a thinking
    block (if reasoning was recorded), all read from the vita message at the
    same position.

    The input list is modified in place and returned. Any error is printed
    and the list is returned as it stands.
    """
    try:
        target = -1
        if litellm_messages[-1]["role"] == "tool":
            target = -2
        source = messages[target]

        blocks = [{"type": "text", "text": source.content}]

        raw_message = source.raw_data["message"]
        if raw_message["tool_calls"]:
            # Only the first tool call is forwarded.
            first_call = raw_message["tool_calls"][0]
            blocks.append(
                {
                    "type": "tool_use",
                    "id": first_call["id"],
                    "name": first_call["function"]["name"],
                    # NOTE(review): passed through as received, presumably a
                    # JSON string rather than a parsed object — confirm.
                    "input": first_call["function"]["arguments"],
                }
            )

        thinking = raw_message.get("reasoning_content", None) or raw_message.get(
            "reasoning", None
        )
        if thinking:
            blocks.append({"type": "thinking", "thinking": thinking})

        litellm_messages[target]["content"] = blocks
    except Exception as e:
        print(e)

    return litellm_messages


def _post_chat_completion(api_url: str, data: dict, headers: dict) -> Any:
    """
    POST the request payload and return the decoded JSON body.

    A response with status 500, or a ``requests`` exception, is retried up to
    three times with a delay of 1, 2 and then 4 seconds. A response with any
    other status code is not retried; its body is decoded and returned.

    Raises:
        requests.exceptions.RequestException: If the last attempt also fails.
    """
    max_retries = 3
    retry_delay = 1  # seconds before the first retry; doubled after each failure

    for attempt in range(max_retries + 1):
        try:
            response = requests.post(api_url, json=data, headers=headers, timeout=(10, 600))

            # Anything other than a 500 is handed back to the caller as is.
            if response.status_code != 500:
                return response.json()

            if attempt < max_retries:
                logger.warning(f"API返回500错误，第{attempt + 1}次重试，{retry_delay}秒后重试...")
                time.sleep(retry_delay)
                retry_delay *= 2  # exponential backoff
            else:
                # Out of retries: surface the 500 as an HTTPError.
                response.raise_for_status()

        except requests.exceptions.RequestException as e:
            if attempt < max_retries:
                logger.warning(f"请求异常，第{attempt + 1}次重试，{retry_delay}秒后重试... 错误: {e}")
                time.sleep(retry_delay)
                retry_delay *= 2
            else:
                raise


def generate(
    model: str,
    messages: list[Message],
    tools: Optional[list[Tool]] = None,
    tool_choice: Optional[str] = None,
    enable_think: bool = False,
    api_url: str = DEFAULT_API_URL,
    **kwargs: Any,
) -> UserMessage | AssistantMessage | None:
    """
    Generate a response from the model.

    Args:
        model: The model to use.
        messages: The messages to send to the model.
        tools: The tools to use.
        tool_choice: The tool choice to use. Defaults to "auto" when tools
            are given.
        enable_think: Whether to enable think mode for the agent.
            NOTE(review): not read anywhere in this function — confirm
            whether it should affect the request.
        api_url: The endpoint the request is posted to.
        **kwargs: Additional arguments; only ``temperature`` and
            ``max_tokens`` are forwarded to the endpoint.

    Returns:
        The assistant message built from the first choice of the response
        (with ``cost`` 0 and no ``usage``), or ``None`` when the request or
        the parsing of its response fails; the error is logged.
    """
    try:
        # NOTE(review): num_retries is filled in but not consumed below; the
        # HTTP retry count is fixed in _post_chat_completion.
        if kwargs.get("num_retries") is None:
            kwargs["num_retries"] = DEFAULT_MAX_RETRIES
        litellm_messages = to_litellm_messages(messages)
        tools = [tool.openai_schema for tool in tools] if tools else None
        if tools and tool_choice is None:
            tool_choice = "auto"

        data = {
            "model": model,
            "messages": litellm_messages,
            "stream": False,
            "temperature": kwargs.get("temperature"),
            "tools": tools,
            "tool_choice": tool_choice,
        }
        # gpt-5 takes its token limit under a different parameter name.
        token_limit_key = "max_completion_tokens" if model == "gpt-5" else "max_tokens"
        data[token_limit_key] = kwargs.get("max_tokens")

        # Use default configuration from config.py
        response = _post_chat_completion(api_url, data, DEFAULT_HEADERS)

        try:
            choice = response["choices"][0]
        except (KeyError, IndexError, TypeError) as e:
            # Fail here with the offending payload instead of continuing and
            # hitting an unrelated KeyError further down.
            raise ValueError(f"bad response: {response}") from e

        reply = choice["message"]
        if reply["role"] != "assistant":
            raise ValueError("The response should be an assistant message")

        tool_calls = []
        for tool_call in reply.get("tool_calls") or []:
            function = tool_call.get("function", {})
            raw_arguments = function.get("arguments")
            tool_calls.append(
                ToolCall(
                    id=tool_call.get("id"),
                    name=function.get("name"),
                    # Arguments arrive as a JSON string; missing or empty
                    # arguments become an empty dict.
                    arguments=json.loads(raw_arguments) if raw_arguments else {},
                )
            )

        return AssistantMessage(
            role="assistant",
            content=reply.get("content"),
            tool_calls=tool_calls or None,
            cost=0,
            usage=None,
            raw_data=choice,
        )
    except Exception as e:
        # Failures are reported to the caller as None rather than raised.
        logger.error(e)
        return None

def get_cost(messages: list[Message]) -> tuple[float, float] | None:
    """
    Sum the cost of an interaction, split between agent and user.

    Tool messages are skipped. Assistant messages count towards the agent
    total and user messages towards the user total.

    Returns:
        ``(agent_cost, user_cost)``, or ``None`` as soon as a non-tool
        message without a cost is found (a warning is logged).
    """
    agent_total = 0
    user_total = 0
    for msg in messages:
        if isinstance(msg, ToolMessage):
            continue
        if msg.cost is None:
            logger.warning(f"Message {msg.role}: {msg.content} has no cost")
            return None
        if isinstance(msg, AssistantMessage):
            agent_total += msg.cost
        elif isinstance(msg, UserMessage):
            user_total += msg.cost
    return agent_total, user_total


def get_token_usage(messages: list[Message]) -> dict:
    """
    Sum the token usage of an interaction between the agent and the user.

    Tool messages are skipped. A message without usage is skipped too, after
    a warning is logged.

    Returns:
        A dict with the summed ``completion_tokens`` and ``prompt_tokens``.
    """
    completion_total = 0
    prompt_total = 0
    for msg in messages:
        if isinstance(msg, ToolMessage):
            continue
        if msg.usage is None:
            logger.warning(f"Message {msg.role}: {msg.content} has no usage")
            continue
        completion_total += msg.usage["completion_tokens"]
        prompt_total += msg.usage["prompt_tokens"]
    return {"completion_tokens": completion_total, "prompt_tokens": prompt_total}
