#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LLM工厂类和管理器
提供统一的接口来创建和管理不同类型的LLM实例
"""

import yaml
import os
from typing import Dict, Any, Optional
import logging

from .base_llm import BaseLLM
from .openai_llm import OpenAICompatibleLLM

logger = logging.getLogger(__name__)


class LLMFactory:
    """Singleton factory that builds LLM clients from a YAML model config.

    The configuration is loaded once, when the shared instance is first
    constructed. The config is read through these top-level keys:

    * ``models`` -- mapping of model name to that model's settings.
    * ``default_model`` -- name used when ``create_llm`` gets no name.
    * ``chat_templates`` -- returned as-is by ``get_chat_templates``.
    """

    _instance = None
    _config = None
    _chat_templates = None

    def __new__(cls, config_path: Optional[str] = None):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """Load the configuration on the first construction only.

        Python calls ``__init__`` on every ``LLMFactory(...)`` expression,
        even though ``__new__`` hands back the same object, so the
        ``_initialized`` flag guards against reloading. A ``config_path``
        passed to a later construction is therefore ignored; call
        ``load_config`` explicitly to switch configuration files.
        """
        if not getattr(self, "_initialized", False):
            self.load_config(config_path)
            self._initialized = True

    @staticmethod
    def _project_root() -> str:
        """Return the directory three levels above this source file."""
        return os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    def _require_config(self) -> Dict[str, Any]:
        """Return the loaded config or raise if nothing has been loaded."""
        if self._config is None:
            raise RuntimeError("配置未加载，请先调用load_config()")
        return self._config

    def load_config(self, config_path: Optional[str] = None):
        """Load the YAML configuration and make it the active config.

        Args:
            config_path: Path to the YAML file. Defaults to
                ``<project root>/configs/models.yaml``.

        Returns:
            The configuration dict that is now active.

        Raises:
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        if config_path is None:
            config_path = os.path.join(
                self._project_root(), "configs", "models.yaml"
            )

        if not os.path.exists(config_path):
            logger.warning(f"配置文件不存在: {config_path}，使用默认配置")
            # The fallback must be stored on the instance: previously it was
            # only returned, which left ``_config`` as None and made every
            # other method raise "config not loaded".
            config = {
                "llm_providers": {
                    "openai": {
                        "provider": "openai",
                        "models": []
                    }
                }
            }
        else:
            with open(config_path, 'r', encoding='utf-8') as f:
                # safe_load yields None for an empty document.
                config = yaml.safe_load(f) or {}
            if not isinstance(config, dict):
                raise ValueError(
                    f"配置文件格式错误: {config_path}，顶层必须是映射(dict)"
                )

        self._config = config
        # ``or {}`` also covers a key that is present but null in the YAML.
        self._chat_templates = config.get("chat_templates") or {}
        return config

    def get_available_models(self) -> list:
        """Return the names listed under the config's ``models`` section.

        Raises:
            RuntimeError: If no configuration has been loaded.
        """
        return list((self._require_config().get("models") or {}).keys())

    def get_default_model(self) -> str:
        """Return the configured ``default_model``, or ``""`` if unset.

        Raises:
            RuntimeError: If no configuration has been loaded.
        """
        return self._require_config().get("default_model", "")

    def create_llm(self, model_name: Optional[str] = None) -> BaseLLM:
        """Create an LLM client for ``model_name``.

        Args:
            model_name: Key under the config's ``models`` section. When
                ``None``, the configured default model is used.

        Returns:
            An ``OpenAICompatibleLLM`` built from the model's settings.

        Raises:
            RuntimeError: If no configuration has been loaded.
            ValueError: If the model name is unknown or its ``type`` is not
                one of ``openai``, ``vllm_server``, ``openai_server``.
        """
        config = self._require_config()

        if model_name is None:
            model_name = self.get_default_model()

        # Use .get so a config without a ``models`` section reports an
        # unknown model instead of failing with a bare KeyError.
        models = config.get("models") or {}
        if model_name not in models:
            available_models = self.get_available_models()
            raise ValueError(f"未知的模型: {model_name}. 可用模型: {available_models}")

        # Shallow copy so path resolution does not rewrite the shared config.
        model_config = models[model_name].copy()
        model_type = str(model_config.get("type") or "openai").lower()

        model_config = self._resolve_paths(model_config)

        logger.info(f"创建{model_type}类型的LLM: {model_name}")

        if model_type in ("openai", "vllm_server", "openai_server"):
            return OpenAICompatibleLLM(model_config)
        raise ValueError(f"不支持的模型类型: {model_type}，仅支持 openai")

    def _resolve_paths(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Expand ``model_path`` / ``tokenizer_path`` entries in place.

        A value starting with ``./`` is joined onto the project root and a
        value starting with ``~/`` has the home directory expanded. Any
        other value is left untouched.

        Args:
            config: A single model's settings; it is modified in place.

        Returns:
            The same ``config`` object.
        """
        base_dir = self._project_root()

        for key in ("model_path", "tokenizer_path"):
            path = config.get(key)
            if not isinstance(path, str):
                continue
            if path.startswith("./"):
                config[key] = os.path.join(base_dir, path[2:])
            elif path.startswith("~/"):
                config[key] = os.path.expanduser(path)

        return config

    def get_chat_templates(self) -> Dict[str, Dict[str, str]]:
        """Return the ``chat_templates`` section captured at load time.

        Raises:
            RuntimeError: If no configuration has been loaded.
        """
        self._require_config()
        return self._chat_templates


class LLMManager:
    """High-level wrapper that keeps one active LLM and forwards calls to it.

    The manager obtains LLM instances from ``LLMFactory`` and exposes
    ``generate``, ``batch_generate`` and ``chat`` on whichever model is
    currently active, loading the configured default model on demand when
    none is active.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create the manager and try to activate an initial model.

        Initialisation is best-effort: a failure is logged and leaves the
        manager without an active model rather than raising.

        Args:
            config_path: Forwarded to ``LLMFactory``.
        """
        self.factory = LLMFactory(config_path)
        self.current_llm: Optional[BaseLLM] = None
        self.current_model_name: Optional[str] = None

        try:
            available_models = self.list_models()
        except Exception as e:
            logger.warning(f"初始化默认模型失败: {e}")
            return

        if not available_models:
            return

        # NOTE(review): this picks the first listed model, whereas the lazy
        # path in generate()/chat() uses the factory's configured
        # ``default_model`` -- confirm which one is intended.
        default_model = available_models[0]
        # switch_model reports failure via its return value (and logs the
        # cause), so only announce success when it actually succeeded.
        if self.switch_model(default_model):
            logger.info(f"已初始化默认模型: {default_model}")

    def list_models(self) -> list:
        """Return the model names the factory reports as available."""
        return self.factory.get_available_models()

    def switch_model(self, model_name: str) -> bool:
        """Make ``model_name`` the active model.

        The new instance is created before the current one is replaced, so a
        failed switch leaves the previously active model (and its name)
        untouched instead of leaving the manager in a half-updated state.

        Args:
            model_name: Name of the model to activate.

        Returns:
            ``True`` if the model is now active (or already was),
            ``False`` if creating it failed; the error is logged.
        """
        if self.current_model_name == model_name and self.current_llm is not None:
            return True

        try:
            new_llm = self.factory.create_llm(model_name)
        except Exception as e:
            logger.error(f"模型切换失败: {e}")
            return False

        # Rebinding drops the manager's reference to the previous instance.
        self.current_llm = new_llm
        self.current_model_name = model_name
        return True

    def _ensure_llm(self) -> "BaseLLM":
        """Return the active LLM, loading the default model if necessary.

        Raises:
            RuntimeError: If no model is active and the default model
                could not be loaded.
        """
        if self.current_llm is None:
            default_model = self.factory.get_default_model()
            if not self.switch_model(default_model):
                raise RuntimeError(f"无法加载默认模型: {default_model}")
        return self.current_llm

    def generate(self, prompt: str, **kwargs) -> str:
        """Forward ``prompt`` and ``kwargs`` to the active model's ``generate``.

        Raises:
            RuntimeError: If no model is active and the default cannot load.
        """
        return self._ensure_llm().generate(prompt, **kwargs)

    def batch_generate(self, prompts: list, **kwargs) -> list:
        """Forward ``prompts`` and ``kwargs`` to ``batch_generate``.

        Raises:
            RuntimeError: If no model is active and the default cannot load.
        """
        return self._ensure_llm().batch_generate(prompts, **kwargs)

    def chat(self, messages: list, **kwargs) -> str:
        """Forward ``messages`` plus the factory's chat templates to ``chat``.

        Raises:
            RuntimeError: If no model is active and the default cannot load.
        """
        llm = self._ensure_llm()
        chat_templates = self.factory.get_chat_templates()
        return llm.chat(messages, chat_templates, **kwargs)

    def get_current_model(self) -> Optional[str]:
        """Return the active model's name, or ``None`` if none is active."""
        return self.current_model_name

    def get_model_info(self) -> Optional[Dict[str, Any]]:
        """Return a dict of the active model's settings.

        Returns:
            ``None`` when no model is active; otherwise the model name plus
            the ``model_path``, ``max_tokens``, ``temperature``, ``top_p``,
            ``repetition_penalty`` and ``chat_template`` attributes read
            from the active LLM instance.
        """
        llm = self.current_llm
        if llm is None:
            return None

        return {
            "model_name": self.current_model_name,
            "model_path": llm.model_path,
            "max_tokens": llm.max_tokens,
            "temperature": llm.temperature,
            "top_p": llm.top_p,
            "repetition_penalty": llm.repetition_penalty,
            "chat_template": llm.chat_template,
        }
