import json
import os
import traceback
import time
from typing import List, Dict, Optional
from schema import ApiConfig
import logger
from tool.registry import get_openai_tools, find_tool
import utils



# JSON file whose "api_configs" list is read by LLM.load_api_config.
# A relative path is resolved against the process's current working directory.
config_file: str = "config.json"


class LLM:
    """Client for an OpenAI-compatible chat-completions endpoint.

    Endpoint, model and sampling settings come from ``self.api_config``, which
    is either loaded by name from ``config_file`` or taken from
    ``utils.api_config``.
    """

    def __init__(self, config_name: Optional[str] = None):
        """
        :param config_name: name of an entry under ``api_configs`` in
            ``config_file`` (matched case-insensitively); when empty/None,
            ``utils.api_config`` is used instead.
        """
        self.config_name = config_name
        if config_name:
            self.api_config = self.load_api_config(config_name)
        else:
            self.api_config = utils.api_config

    @staticmethod
    def load_api_config(config_name: str) -> Optional["ApiConfig"]:
        """Load the API configuration named ``config_name`` from ``config_file``.

        The name is matched case-insensitively. When no entry matches, the
        first configured entry is returned and a warning is logged; when the
        file lists no configurations at all, ``None`` is returned.
        """
        with open(config_file, "r", encoding="utf-8") as file:
            data = json.load(file)
        api_configs = [
            ApiConfig.from_dict(config) for config in data.get("api_configs", [])
        ]
        wanted = config_name.lower()
        for config in api_configs:
            if (config.config_name or "").lower() == wanted:
                return config

        if not api_configs:
            return None
        # Silently switching to a different model is surprising; make it visible.
        logger.warning(
            f"【LLM配置】未找到配置 {config_name}，回退使用 {api_configs[0].config_name}"
        )
        return api_configs[0]

    @staticmethod
    def check_api_key(api_key_env: Optional[str]) -> str:
        """Return the API key stored in the environment variable ``api_key_env``.

        An empty/None variable name means the endpoint needs no key, and ``""``
        is returned.

        :raises RuntimeError: if the variable is named but unset or empty.
        """
        if not api_key_env:
            return ""
        api_key = os.getenv(api_key_env)
        if not api_key:
            raise RuntimeError(
                f"{api_key_env} is not set. Please set the environment variable."
            )
        return api_key

    @staticmethod
    def check_glm_base_url(base_url_env: str = "BASE_HOST") -> Optional[str]:
        """Build the GLM API base URL from the host stored in ``base_url_env``.

        :return: ``"<host>/api/paas/v4/"``, or ``None`` when the variable is
            unset/empty so the caller can fall back to the configured base_url.
        """
        base_url = os.getenv(base_url_env)
        if not base_url:
            return None
        # Strip trailing slashes so "http://host/" does not produce "//api/...".
        return f"{base_url.rstrip('/')}/api/paas/v4/"

    @staticmethod
    def _extract_token_usage(usage) -> tuple:
        """Return ``(prompt_tokens, completion_tokens, total_tokens)`` from ``usage``.

        ``usage`` may be an object with attributes or a plain dict. The
        Chat Completions names (prompt_/completion_tokens) are preferred; the
        input_/output_tokens names are used when prompt_tokens is absent.
        Missing fields come back as ``None``.
        """

        def read(name):
            if isinstance(usage, dict):
                return usage.get(name)
            return getattr(usage, name, None)

        if isinstance(usage, dict):
            has_prompt_tokens = "prompt_tokens" in usage
        else:
            has_prompt_tokens = hasattr(usage, "prompt_tokens")

        if has_prompt_tokens:
            return read("prompt_tokens"), read("completion_tokens"), read("total_tokens")
        return read("input_tokens"), read("output_tokens"), read("total_tokens")

    def ask(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        max_retries: int = 3,
        temperature: Optional[float] = None,
        stream: Optional[bool] = None,
    ):
        """Send one chat-completion request, retrying with exponential backoff.

        :param messages: chat messages in OpenAI format
        :param tools: OpenAI tool definitions; left out of the request when empty
        :param max_retries: total number of attempts including the first one
            (despite the name); values below 1 are treated as 1
        :param temperature: overrides ``api_config.temperature`` when not None
        :param stream: overrides ``api_config.stream`` when not None
        :return: the object returned by ``client.chat.completions.create``
            (a stream object when streaming is enabled)
        :raises RuntimeError: if no base_url is configured or the API key is missing
        :raises Exception: the last request error once every attempt has failed
        """
        tools = tools or []
        model = self.api_config.model
        if temperature is None:
            temperature = self.api_config.temperature
        if stream is None:
            stream = self.api_config.stream
        # Guarantees at least one attempt, so last_error is always set below.
        attempts = max(1, max_retries)

        from openai import OpenAI

        if not self.api_config.base_url:
            raise RuntimeError("通用OpenAI接口配置 需要 base_url 参数")

        if self.api_config.type == "GLM":
            # A BASE_HOST environment variable takes precedence over the configured URL.
            base_url = LLM.check_glm_base_url() or self.api_config.base_url
        else:
            base_url = self.api_config.base_url
        client = OpenAI(
            base_url=base_url,
            api_key=LLM.check_api_key(self.api_config.api_key_env),
        )

        kwargs = {"model": model, "messages": messages, "stream": stream, "temperature": temperature}
        if tools:  # only add if not empty
            kwargs["tools"] = tools
        if self.api_config.enable_thinking:
            kwargs["extra_body"] = {"enable_thinking": True}

        last_error: Optional[Exception] = None
        for attempt in range(attempts):
            if attempt > 0:
                # Exponential backoff: wait 2**(n-1) seconds before the n-th retry (1s, 2s, 4s, ...).
                wait_time = 2 ** (attempt - 1)
                logger.warning(f"【LLM重试】第 {attempt} 次重试，等待 {wait_time} 秒...")
                time.sleep(wait_time)

            try:
                logger.trace("【请求LLM回答】", str(messages)," 可用工具: ", [str(t['function']['name']) for t in tools])
                response = client.chat.completions.create(**kwargs)
                logger.trace("【LLM返回结果】", str(response))

                # A streaming response has no .choices/.usage until it is consumed,
                # so these checks only apply to non-streaming responses.
                if not stream:
                    if response.choices[0].finish_reason == "length":
                        logger.warning("【回答长度过长】")
                    try:
                        usage = getattr(response, "usage", None)
                        if usage is not None:
                            prompt_tokens, completion_tokens, total_tokens = LLM._extract_token_usage(usage)
                            logger.log_token_usage(
                                model=model,
                                prompt_tokens=prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=total_tokens,
                            )
                    except Exception:
                        # Token accounting is best-effort and must never fail the request.
                        logger.debug("【LLM Token】统计失败但已忽略。")

                return response

            except Exception as e:
                last_error = e
                logger.warning(f"【请求LLM失败】第 {attempt + 1}/{attempts} 次尝试: {e}")
                if attempt == attempts - 1:
                    logger.error(f"【请求回答出错】已重试 {attempts} 次均失败: {e}\n{traceback.format_exc()}")

        # Every attempt failed: surface the most recent error.
        raise last_error

    @staticmethod
    def _run_tool_call(tc) -> Dict:
        """Execute one model-requested tool call and return the matching "tool" message.

        Argument-parsing errors, unregistered tools and exceptions raised by the
        tool are reported back to the model as an ``{"error": ...}`` payload
        instead of being raised.
        """
        tool_name = tc.function.name
        raw_args = tc.function.arguments
        logger.info(f"【执行工具】{tool_name} 原始参数: {raw_args}")
        try:
            args = json.loads(raw_args) if raw_args else {}
        except (ValueError, TypeError) as e:  # JSONDecodeError subclasses ValueError
            logger.warning(f"【解析工具参数失败】{tool_name}: {e}")
            args = {}

        tool_impl = find_tool(tool_name)
        if tool_impl is None:
            logger.warning(f"【工具未注册】{tool_name}，返回错误。")
            tool_output_dict = {"error": f"tool {tool_name} not registered"}
        else:
            try:
                result = tool_impl.execute(**args)
                tool_output_dict = result.to_dict()
            except Exception as e:
                logger.error(
                    f"【工具执行异常】{tool_name}: {e}\n{traceback.format_exc()}"
                )
                tool_output_dict = {"error": f"exception: {e}"}

        tool_content = json.dumps(tool_output_dict, ensure_ascii=False)
        # run_datascience output is deliberately not logged (presumably too large; confirm).
        if tool_name != "run_datascience":
            logger.info(f"【工具结果】{tool_name}: {tool_content}")
        return {
            "role": "tool",
            "tool_call_id": tc.id,
            "name": tool_name,
            "content": tool_content,
        }

    def ask_with_tools(
        self,
        messages: List[Dict],
        tool_names: Optional[List[str]] = None,
        max_loops: int = 5,
        temperature: Optional[float] = None,
    ):
        """Chat with function calling, automatically executing requested tool calls.

        ``messages`` is extended in place with the assistant turns and the tool
        results.

        :param messages: initial chat messages
        :param tool_names: allowed tool names (default: all registered tools)
        :param max_loops: maximum number of model rounds, guards against endless loops
        :param temperature: overrides the configured temperature when not None
        :return: the last response object; normally the first one without
            tool calls, or the last round's response if ``max_loops`` runs out
        :raises ValueError: if ``max_loops`` is less than 1
        """
        if max_loops < 1:
            raise ValueError(f"max_loops must be >= 1, got {max_loops}")

        try:
            available_tools = get_openai_tools(allowed_names=tool_names)
            last_response = None

            for _ in range(max_loops):
                # Streaming is forced off: this loop reads response.choices[0].message.
                response = self.ask(
                    messages=messages,
                    tools=available_tools,
                    temperature=temperature,
                    stream=False,
                )
                last_response = response

                msg = response.choices[0].message
                tool_calls = getattr(msg, "tool_calls", None)

                # No tool calls: this is the final answer.
                if not tool_calls:
                    messages.append({"role": "assistant", "content": msg.content})
                    break

                logger.info(f"【智能体回答】 {msg.content.strip() if msg.content else ''}")
                # The OpenAI API expects "tool" messages to follow the assistant
                # message that carries the matching tool_calls.
                try:
                    messages.append({
                        "role": "assistant",
                        "content": msg.content or "",
                        "tool_calls": [
                            {
                                "id": tc.id,
                                "function": {
                                    "name": tc.function.name,
                                    "arguments": tc.function.arguments,
                                },
                                "type": tc.type,
                            }
                            for tc in tool_calls
                        ],
                    })
                except Exception as e:
                    # Best-effort: still run the tools, but do not hide the failure.
                    logger.warning(f"【组装工具调用消息失败】{e}")

                for tc in tool_calls:
                    messages.append(self._run_tool_call(tc))
            else:
                logger.warning(f"【工具循环结束】已达到最大轮次 {max_loops}，模型仍在请求工具调用。")

            logger.info("【回答结果】" + str(last_response.choices[0].message.content))
            return last_response
        except Exception as e:
            logger.error(f"【请求回答(工具)出错】: {e}\n{traceback.format_exc()}")
            raise
