import json
import os
import pickle
import time

import numpy as np
import openai
from openai import OpenAI

from core.model.consts import response_format


class LLMApi:
    def __init__(self, model_type="doubao", config_path="core/infra/config.json"):
        """Initialise the LLM API client wrapper.

        Reads endpoint/credential settings for both the Doubao and the Azure
        backends from a JSON config file. Both SDK clients are left unset;
        ``get_doubao_client`` / ``get_azure_client`` create them lazily. Any
        on-disk caches are loaded via ``load_cache``.

        Args:
            model_type (str): Backend to use; one of "doubao", "azure_gpt",
                "azure_gpt_thinking".
            config_path (str): Path of the JSON config file. Defaults to the
                previously hard-coded "core/infra/config.json", resolved
                relative to the current working directory.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            KeyError: If a required config key is missing. The "azure"
                section is required even when ``model_type`` is "doubao".
        """
        self.model_type = model_type

        # Decode explicitly as UTF-8 (tolerating a leading BOM) instead of
        # relying on the platform's locale-dependent default encoding.
        with open(config_path, "r", encoding="utf-8-sig") as f:
            config = json.load(f)

        # Doubao (OpenAI-compatible endpoint) settings
        self.api_key = config["api_key"]
        self.base_url = config["base_url"]
        self.model_ep = config["model_ep"]
        self.embedding_model_ep = config["embedding_model_ep"]

        # Azure OpenAI settings
        azure_cfg = config["azure"]
        self.azure_endpoint = azure_cfg["endpoint"]
        self.azure_api_version = azure_cfg["api_version"]
        self.azure_api_key = azure_cfg["api_key"]
        self.azure_model_name = azure_cfg["model_name"]
        # max_tokens sent with every Azure chat request
        self.max_tokens = 1000

        # In-memory caches (overwritten by load_cache below) and lazily-built clients
        self.generate_cache = {}
        self.embedding_cache = {}
        self.doubao_client = None
        self.azure_client = None

        self.load_cache()

    def get_doubao_client(self):
        """Return the Doubao (OpenAI-compatible) client, building it on first use.

        The client is constructed from ``api_key`` and ``base_url`` with
        ``timeout=120`` and ``max_retries=5``. It is then memoised on
        ``self.doubao_client`` so later calls reuse it.
        """
        existing = self.doubao_client
        if existing is not None:
            return existing

        self.doubao_client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=120,
            max_retries=5,
        )
        return self.doubao_client

    def get_azure_client(self):
        """Return the Azure OpenAI client, constructing it on the first call.

        Built from the ``azure_endpoint``, ``azure_api_version`` and
        ``azure_api_key`` settings. Cached on ``self.azure_client`` for reuse.
        """
        existing = self.azure_client
        if existing is not None:
            return existing

        self.azure_client = openai.AzureOpenAI(
            azure_endpoint=self.azure_endpoint,
            api_version=self.azure_api_version,
            api_key=self.azure_api_key,
        )
        return self.azure_client

    @staticmethod
    def truncate_text(text, max_chars=100000):
        return text[:max_chars] if len(text) > max_chars else text

    def load_cache(self):
        """Load the embedding and generation caches from the working directory.

        Reads ``embedding_cache.pkl`` and ``generate_cache.pkl``. A missing
        file yields an empty cache. A truncated or corrupt file (e.g. one left
        by an interrupted write) is reported and also treated as an empty
        cache, instead of aborting object construction.

        NOTE(security): ``pickle.load`` can execute arbitrary code. These
        files must only ever be produced by this class's own ``save_cache``;
        never point it at untrusted data.
        """
        self.embedding_cache = self._load_pickle_cache("embedding_cache.pkl")
        self.generate_cache = self._load_pickle_cache("generate_cache.pkl")

    @staticmethod
    def _load_pickle_cache(path):
        """Return the object pickled at ``path``, or {} if absent or unreadable."""
        try:
            with open(path, "rb") as f:
                return pickle.load(f)
        except FileNotFoundError:
            return {}
        except (EOFError, pickle.UnpicklingError, AttributeError, ImportError, IndexError) as e:
            # These are the errors the pickle docs list for bad input. The
            # cache is only an optimisation, so start empty rather than crash.
            print(f"Warning: ignoring unreadable cache file {path}: {e}")
            return {}

    def save_cache(self):
        """Persist both caches as pickle files in the current working directory.

        Each cache is first written to ``<name>.tmp`` and then moved over the
        real file with ``os.replace``. An interruption mid-write therefore
        leaves the previous cache file intact, rather than a truncated pickle
        that would break the next ``load_cache``.

        Raises:
            OSError / pickle.PicklingError: Propagated from the write, as
                before. The temporary file is removed first.
        """
        for path, cache in (
            ("embedding_cache.pkl", self.embedding_cache),
            ("generate_cache.pkl", self.generate_cache),
        ):
            tmp_path = path + ".tmp"
            try:
                with open(tmp_path, "wb") as f:
                    pickle.dump(cache, f)
                os.replace(tmp_path, path)
            except BaseException:
                # Don't leave a half-written temp file behind; re-raise so the
                # caller still sees the original failure.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best effort; the original error is what matters
                raise

    def _generate_doubao(self, messages, temperature=0.1, max_attempts=3, sleep_time=10):
        """Ask the Doubao model for a structured answer and return its ``final_result``.

        Each message is reduced to its ``role`` and ``content`` keys, with the
        content passed through ``truncate_text``. The request uses
        ``response_format`` for structured parsing and enables the "thinking"
        option. The reply content is decoded as JSON, and ``str()`` of its
        ``"final_result"`` field is returned.

        Any exception inside an attempt (request error, JSON decoding error,
        missing key, ...) counts as a failure. The call then sleeps
        ``sleep_time`` seconds and retries. After ``max_attempts`` failures the
        last error is printed and ``None`` is returned.
        """
        trimmed = [
            {"role": msg["role"], "content": self.truncate_text(msg["content"])}
            for msg in messages
        ]

        failures = 0
        while failures < max_attempts:
            try:
                client = self.get_doubao_client()
                response = client.beta.chat.completions.parse(
                    model=self.model_ep,
                    messages=trimmed,
                    temperature=temperature,
                    # FIXME: the next two arguments are inconsistent with the
                    # other (diffed) version of this call.
                    response_format=response_format,  # schema used to parse the reply
                    extra_body={
                        # "enabled" turns deep thinking on; "disabled" turns it off
                        "thinking": {"type": "enabled"}
                    },
                )
                # FIXME (diff): the other version instead returns
                # response.choices[0].message.content.strip()
                parsed = json.loads(response.choices[0].message.content)
                return str(parsed["final_result"])
            except Exception as e:
                failures += 1
                if failures >= max_attempts:
                    print(f"Doubao API调用失败 after {max_attempts} attempts: {e}")
                    return None
                time.sleep(sleep_time)
        return None

    def _generate_azure_gpt(self, messages, max_attempts=3, sleep_time=10, include_thinking=False):
        """Query the Azure OpenAI chat model and return the reply text.

        Only a single user turn is sent. A plain string becomes that turn.
        For a message list, only the content of the LAST ``"user"`` message
        is kept; system and assistant turns are dropped. See
        ``_to_azure_messages`` for the exact conversion.

        Args:
            messages: A prompt string, or a list of ``{"role", "content"}`` dicts.
            max_attempts (int): Total number of request attempts.
            sleep_time (float): Seconds to wait between failed attempts.
            include_thinking (bool): If True, also send
                ``extra_body={"thinking": {"include_thoughts": True}}``.

        Returns:
            The first choice's ``message.content``, or None if every attempt
            raised. The last error is printed in that case.
        """
        # Request parameters do not change between attempts; build them once.
        create_params = {
            "model": self.azure_model_name,
            "messages": self._to_azure_messages(messages),
            "max_tokens": self.max_tokens,
            "extra_headers": {"X-TT-LOGID": ""},
        }
        if include_thinking:
            create_params["extra_body"] = {"thinking": {"include_thoughts": True}}

        for attempt in range(1, max_attempts + 1):
            try:
                client = self.get_azure_client()
                completion = client.chat.completions.create(**create_params)
                return completion.choices[0].message.content
            except Exception as e:
                if attempt < max_attempts:
                    time.sleep(sleep_time)
                else:
                    print(f"Azure GPT API调用失败 after {max_attempts} attempts: {e}")
        return None

    @staticmethod
    def _to_azure_messages(messages):
        """Convert ``messages`` into a one-element user-turn list for Azure chat.

        - str: used as the text.
        - non-empty list: the content of the last dict with role "user".
          List content (already OpenAI content parts, e.g. text + image_url)
          is passed through unchanged. None becomes "", and other non-str
          values are str()-ed. If there is no user message, str(messages)
          is used.
        - anything else (including an empty list): str(messages).
        """
        if isinstance(messages, list) and messages:
            user_msg = next(
                (m for m in reversed(messages) if isinstance(m, dict) and m.get("role") == "user"),
                None,
            )
            content = str(messages) if user_msg is None else user_msg.get("content")
        elif isinstance(messages, str):
            content = messages
        else:
            content = str(messages)

        if isinstance(content, list):
            # Already a list of content parts; wrapping it as "text" would be invalid.
            return [{"role": "user", "content": content}]
        if content is None:
            content = ""
        elif not isinstance(content, str):
            content = str(content)
        return [{"role": "user", "content": [{"type": "text", "text": content}]}]

    def generate(self, messages, temperature=0.1, max_attempts=3, sleep_time=10):
        """Unified, cached generation entry point for all backends.

        The cache key is a ``sort_keys`` JSON dump of the model type, the
        messages and the temperature. On a hit, the cached value is returned
        without any API call. On a miss, the request is routed by
        ``model_type``. A non-None result is stored in ``generate_cache``, and
        both caches are saved via ``save_cache``.

        Args:
            messages: Message list (Doubao) or prompt string / message list
                (Azure variants). Must be JSON-serialisable for the cache key.
            temperature (float): Sampling temperature. Only forwarded to the
                Doubao backend, but always part of the cache key.
            max_attempts (int): Retry budget passed to the backend.
            sleep_time (float): Delay between retries, passed to the backend.

        Returns:
            The backend's result, or None if it failed.

        Raises:
            ValueError: On a cache miss with an unsupported ``model_type``.
        """
        key_fields = {
            "messages": messages,
            "model_type": self.model_type,
            "temperature": temperature,
        }
        cache_key = json.dumps(key_fields, sort_keys=True)
        if cache_key in self.generate_cache:
            return self.generate_cache[cache_key]

        backend = self.model_type
        if backend == "doubao":
            result = self._generate_doubao(messages, temperature, max_attempts, sleep_time)
        elif backend in ("azure_gpt", "azure_gpt_thinking"):
            # The "_thinking" variant differs only by requesting thoughts.
            result = self._generate_azure_gpt(
                messages,
                max_attempts,
                sleep_time,
                include_thinking=backend == "azure_gpt_thinking",
            )
        else:
            raise ValueError(f"不支持的模型类型: {self.model_type}")

        # Failures (None) are never cached, so they are retried next time.
        if result is None:
            return None
        self.generate_cache[cache_key] = result
        self.save_cache()
        return result

    def chat(self, question, temperature=0.1, max_attempts=3, sleep_time=10):
        """Send a single user question through ``generate``.

        For the Doubao backend, the question is wrapped as a one-element user
        message list. The Azure variants receive the raw question, which
        ``_generate_azure_gpt`` converts itself.
        """
        payload = question
        if self.model_type == "doubao":
            payload = [{"role": "user", "content": question}]
        return self.generate(payload, temperature, max_attempts, sleep_time)

    def get_embedding(self, text, max_attempts=3, sleep_time=20, expected_dim=2560):
        """Return the embedding of ``text`` as a float32 numpy vector (Doubao only).

        ``text`` is first truncated with ``truncate_text``, and the truncated
        text is the cache key. On a cache miss, the embedding endpoint is
        called up to ``max_attempts`` times, sleeping ``sleep_time`` seconds
        after each failed request.

        Args:
            text (str): Text to embed.
            max_attempts (int): Maximum number of API attempts.
            sleep_time (float): Seconds to wait between failed requests.
            expected_dim (int): Required embedding length. Defaults to 2560,
                the previously hard-coded value. A result of any other shape
                is replaced by zeros.

        Returns:
            np.ndarray: Array of shape ``(expected_dim,)``, dtype float32. If
            every attempt fails, a zero vector is returned and not cached.

        Raises:
            ValueError: If ``model_type`` is not "doubao".
        """
        if self.model_type != "doubao":
            raise ValueError("词向量嵌入功能仅支持Doubao模型")

        text = self.truncate_text(text)
        if text in self.embedding_cache:
            return self.embedding_cache[text]

        for attempt in range(max_attempts):
            # Only the network call lives in this try. Local failures (e.g.
            # persisting the cache) are therefore not misreported as API
            # errors and do not trigger a redundant re-request.
            try:
                client = self.get_doubao_client()
                resp = client.embeddings.create(
                    model=self.embedding_model_ep,
                    input=[text]
                )
                embedding_list = resp.data[0].embedding
            except Exception as e:
                print(f"Embedding API call failed: {e}")
                if attempt < max_attempts - 1:
                    time.sleep(sleep_time)
                    continue
                print(f"Embedding API调用失败 after {max_attempts} attempts: {e}")
                return np.zeros(expected_dim, dtype=np.float32)

            try:
                embedding = np.array(embedding_list, dtype=np.float32)
            except (TypeError, ValueError) as ve:
                # Malformed payload: retry immediately (no sleep), as before.
                print(f"Failed to convert embedding to numpy array: {ve}")
                continue

            # Compare the full shape, not len(): 2-D results must be rejected,
            # and len() of a 0-d array would raise.
            if embedding.shape != (expected_dim,):
                print(
                    f"Warning: Embedding shape is {embedding.shape} for text: {text[:50]}..., "
                    f"expected ({expected_dim},). Using zeros."
                )
                # NOTE: this zero fallback is cached like a real embedding, so
                # this text will not be re-requested later.
                embedding = np.zeros(expected_dim, dtype=np.float32)

            self.embedding_cache[text] = embedding
            try:
                self.save_cache()
            except OSError as e:
                # The embedding is valid and already cached in memory; failing
                # to persist it should not discard it or force a re-request.
                print(f"Warning: failed to save embedding cache: {e}")
            return embedding

        print(f"Embedding API调用失败 after {max_attempts} attempts")
        return np.zeros(expected_dim, dtype=np.float32)