"""
OpenAI API client with retry logic and error handling
"""
import asyncio
import aiohttp
import json
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import OPENAI_CONFIG, API_CONFIGS, MODEL_TO_API_CONFIG

logger = logging.getLogger(__name__)


@dataclass
class APIResponse:
    """Result record for one chat-completion call.

    Attributes:
        success: Whether the call produced a usable completion.
        content: Text of the completion message.
        reasoning_content: Reasoning text returned alongside the answer
            (used for models such as deepseek-r1-250120 and gpt-oss-120b);
            ``None`` when nothing was extracted.
        error: Description of the failure, or ``None`` when there is none.
        retries_used: Retry counter recorded by the client that built the
            record.
        prompt_tokens: Number of prompt tokens.
        completion_tokens: Number of completion tokens.
        total_tokens: Total number of tokens.
    """

    # Required fields.
    success: bool
    content: str

    # Optional text payloads.
    reasoning_content: Optional[str] = None
    error: Optional[str] = None

    # Bookkeeping counters; every one of them defaults to zero.
    retries_used: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class UnifiedAPIClient:
    """Unified chat-completion client that supports several API providers.

    The endpoint URL, API key, temperature and retry policy are looked up
    from the project configuration by model name. Requests are sent with
    ``aiohttp`` and retried on failure; a request that keeps failing is
    reported through an ``APIResponse`` with ``success=False``.
    """

    # Models that get the extended request timeout.
    _SLOW_RESPONSE_MODELS = frozenset({
        "gpt-oss-120b",
        "gpt5",
        "claude-opus-4",
        "qwen3-235b-a22b-thinking-2507",
        "grok-4",
        "deepseek-r1-250120",
    })
    _EXTENDED_TIMEOUT_SECONDS = 1000  # roughly 16.7 minutes
    _DEFAULT_TIMEOUT_SECONDS = 120  # 2 minutes

    # Models for which ``enable_thinking`` is added to the request payload.
    _THINKING_MODELS = frozenset({
        "qwen3-235b-a22b-thinking-2507",
        "qwen/qwen3-235b-a22b-thinking-2507",
    })

    # Models whose reasoning is read from the ``reasoning_content`` key.
    _REASONING_CONTENT_KEY_MODELS = frozenset({
        "deepseek-r1-250120",
        "gpt-oss-120b",
        "qwen3-235b-a22b-instruct-2507",
        "qwen3-235b-a22b-thinking-2507",
        "qwen/qwen3-235b-a22b-thinking-2507",
    })

    # Models whose reasoning is read from the ``reasoning`` key.
    _REASONING_KEY_MODELS = frozenset({
        "qwen/qwen3-235b-a22b-thinking-2507",
        "openai/gpt-oss-120b",
    })

    def __init__(self, model: str):
        """Configure the client for ``model``.

        Args:
            model: Model identifier used to select the provider settings.

        Raises:
            ValueError: If ``model`` is found neither in
                ``MODEL_TO_API_CONFIG`` nor in ``OPENAI_CONFIG["models"]``.
        """
        self.model = model

        # Pick the provider configuration that matches the model.
        if model in MODEL_TO_API_CONFIG:
            config = API_CONFIGS[MODEL_TO_API_CONFIG[model]]
        elif model in OPENAI_CONFIG["models"]:
            # Backward compatibility: models listed only in OPENAI_CONFIG
            # use the "openai_compatible" provider.
            config = API_CONFIGS["openai_compatible"]
        else:
            raise ValueError(f"Unsupported model: {model}. Supported: {list(MODEL_TO_API_CONFIG.keys())}")

        self.base_url = config["base_url"]
        self.api_key = config["api_key"]
        self.temperature = config["temperature"]
        self.max_retries = config["max_retries"]
        self.retry_delay = config["retry_delay"]

        # Slow-responding models get a much longer total request timeout.
        if model in self._SLOW_RESPONSE_MODELS:
            self.timeout = self._EXTENDED_TIMEOUT_SECONDS
            logger.info(f"Using extended timeout ({self.timeout}s) for slow-response model: {model}")
        else:
            self.timeout = self._DEFAULT_TIMEOUT_SECONDS

        logger.info(f"Initialized UnifiedAPIClient for model {model} using {self.base_url} (timeout: {self.timeout}s)")

    async def chat_completion(self, messages: List[Dict[str, str]], system_prompt: Optional[str] = None) -> "APIResponse":
        """Send a chat-completion request.

        Args:
            messages: Conversation messages, sent in the given order.
            system_prompt: Optional system message placed before ``messages``.

        Returns:
            The ``APIResponse`` produced by ``_call_with_retry``.
        """
        # Build the message list, system prompt first when one is given.
        formatted_messages = []
        if system_prompt:
            formatted_messages.append({"role": "system", "content": system_prompt})
        formatted_messages.extend(messages)

        payload = {
            "model": self.model,
            "messages": formatted_messages,
            "temperature": self.temperature,
            "max_tokens": 16000,
        }

        # Thinking models additionally receive the enable_thinking flag.
        if self.model in self._THINKING_MODELS:
            payload["enable_thinking"] = True
            logger.info(f"Added enable_thinking=True for thinking model: {self.model}")

        return await self._call_with_retry(payload)

    async def _call_with_retry(self, payload: Dict[str, Any]) -> "APIResponse":
        """POST ``payload`` to the endpoint, retrying on failure.

        Up to ``self.max_retries`` attempts are made. Every failed attempt,
        including an error object inside an HTTP 200 body, is followed by a
        pause of ``self.retry_delay`` seconds unless it was the last attempt.

        Args:
            payload: JSON-serialisable request body.

        Returns:
            A successful ``APIResponse`` from the first attempt that works,
            otherwise a failed one carrying the last error message.
        """
        last_error = None

        # These do not change between attempts, so build them once.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        request_timeout = aiohttp.ClientTimeout(total=self.timeout)

        for attempt in range(self.max_retries):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        self.base_url,
                        headers=headers,
                        json=payload,
                        timeout=request_timeout
                    ) as response:

                        if response.status == 200:
                            result = await response.json()

                            # The API may report an error even with status 200.
                            last_error = self._find_response_error(result)
                            if last_error is None:
                                return self._build_success_response(result, attempt)

                            # Fall through to the shared retry delay below
                            # instead of retrying immediately.
                            logger.warning(f"API returned error for {self.model} on attempt {attempt + 1}: {last_error}")
                        else:
                            error_text = await response.text()
                            last_error = f"HTTP {response.status}: {error_text}"
                            logger.warning(f"API call failed for {self.model} on attempt {attempt + 1}: {last_error}")

            except asyncio.TimeoutError:
                last_error = "Request timeout"
                logger.warning(f"API timeout for {self.model} on attempt {attempt + 1}")
            except Exception as e:
                # Any other failure (connection, JSON decoding, malformed
                # body) is recorded and retried rather than propagated.
                last_error = str(e)
                logger.warning(f"API error for {self.model} on attempt {attempt + 1}: {last_error}")

            # Wait before the next attempt unless this was the last one.
            if attempt < self.max_retries - 1:
                logger.info(f"Waiting {self.retry_delay}s before retry...")
                await asyncio.sleep(self.retry_delay)

        # Every attempt failed.
        logger.error(f"API call failed for {self.model} after {self.max_retries} attempts. Last error: {last_error}")
        return APIResponse(
            success=False,
            content="",
            error=last_error,
            retries_used=self.max_retries,
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0
        )

    @staticmethod
    def _format_api_error(error_info: Any) -> str:
        """Render the ``error`` member of a response body as one line."""
        if isinstance(error_info, dict):
            return f"API Error: {error_info.get('code', 'unknown')} - {error_info.get('message', 'No message')}"
        # Some payloads carry a bare string (or another non-dict value).
        return f"API Error: unknown - {error_info}"

    @classmethod
    def _find_response_error(cls, result: Dict[str, Any]) -> Optional[str]:
        """Return an error description for an HTTP 200 body, or ``None``."""
        if "error" in result:
            return cls._format_api_error(result["error"])
        # A missing "choices" key is tolerated (empty completion), but an
        # explicitly empty or null list cannot be indexed.
        if not result.get("choices", [{}]):
            return "API response contained no choices"
        return None

    def _extract_reasoning(self, message: Dict[str, Any]) -> Optional[str]:
        """Return the reasoning text attached to ``message`` for this model.

        The ``reasoning`` key takes precedence for models that use it, but a
        value already found under ``reasoning_content`` is kept when
        ``reasoning`` is absent or empty.
        """
        reasoning = None
        if self.model in self._REASONING_CONTENT_KEY_MODELS:
            reasoning = message.get("reasoning_content", None)
        if self.model in self._REASONING_KEY_MODELS:
            reasoning = message.get("reasoning", None) or reasoning
        return reasoning

    def _build_success_response(self, result: Dict[str, Any], attempt: int) -> "APIResponse":
        """Convert a successful response body into an ``APIResponse``."""
        message = result.get("choices", [{}])[0].get("message", {})
        # "content" may be present but null; keep the field a string.
        content = message.get("content") or ""

        reasoning_content = self._extract_reasoning(message)
        if reasoning_content:
            logger.info(f"Extracted reasoning_content for {self.model}: {len(reasoning_content)} chars")

        # Token usage; a missing or null "usage" object counts as zero.
        usage = result.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)

        logger.info(f"API call successful for {self.model} on attempt {attempt + 1}, tokens: {prompt_tokens}+{completion_tokens}={total_tokens}")
        return APIResponse(
            success=True,
            content=content,
            reasoning_content=reasoning_content,
            retries_used=attempt,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        )

    def get_model_name(self) -> str:
        """Return the name of the model this client is using."""
        return self.model


class OpenAIClient(UnifiedAPIClient):
    """OpenAI API client, kept for backward compatibility.

    Behaves like ``UnifiedAPIClient`` but only accepts models listed in
    ``OPENAI_CONFIG["models"]``.
    """

    def __init__(self, model: str):
        """Validate ``model`` and then run the parent initialisation.

        Args:
            model: Model identifier; must appear in ``OPENAI_CONFIG["models"]``.

        Raises:
            ValueError: If ``model`` is not listed in
                ``OPENAI_CONFIG["models"]``.
        """
        # Check first so that a rejected model is never configured or
        # logged as initialised by the parent class.
        if model not in OPENAI_CONFIG["models"]:
            raise ValueError(f"Unsupported model: {model}. Supported: {OPENAI_CONFIG['models']}")

        # The parent initialiser resolves the provider configuration.
        super().__init__(model)


# 便利函数：创建API客户端
def create_api_client(model: str) -> UnifiedAPIClient:
    """Build and return a ``UnifiedAPIClient`` configured for ``model``."""
    client = UnifiedAPIClient(model)
    return client