import os
import time
from typing import Callable, List, Optional

import openai
from langchain.chat_models import ChatOpenAI
from langchain.schema import ChatMessage


class GPTWrapper:
    """Callable wrapper around a LangChain ``ChatOpenAI`` model with retries.

    Instances are called with a list of chat messages and return the model's
    reply as a stripped string.

    Attributes:
        model_name: The ``llm_name`` originally requested by the caller (kept
            even when ``long_ver`` swaps the model that is actually used).
        timeout: Per-request timeout, in seconds, handed to ``ChatOpenAI``.
        llm: The underlying ``ChatOpenAI`` instance.
    """

    def __init__(self, llm_name: str, openai_api_key: str, long_ver: bool, openai_base_url: str, timeout: int = 30):
        """Build the underlying chat model.

        Args:
            llm_name: Name of the model to query.
            openai_api_key: API key forwarded to ``ChatOpenAI``.
            long_ver: When true, the model queried is replaced by
                ``'Qwen3/Qwen3-8B'``; ``self.model_name`` still records
                ``llm_name``.
            openai_base_url: Base URL of the OpenAI-compatible endpoint. Used
                when the ``OPENAI_BASE_URL`` environment variable is unset or
                empty; the environment variable keeps precedence so existing
                deployments behave as before.
            timeout: Request timeout in seconds.
        """
        self.model_name = llm_name
        self.timeout = timeout
        if long_ver:
            # NOTE(review): hard-coded served-model name; confirm it matches
            # the name the endpoint actually exposes.
            llm_name = 'Qwen3/Qwen3-8B'

        # Previously the argument was ignored and a missing environment
        # variable raised KeyError; now the argument acts as the fallback.
        base_url = os.environ.get('OPENAI_BASE_URL') or openai_base_url

        llm_kwargs = {
            'model': llm_name,
            'temperature': 0.0,
            'openai_api_key': openai_api_key,
            'request_timeout': timeout,
            # The thinking switch is supplied both nested under ``extra_body``
            # and as a top-level keyword, exactly as before.
            # NOTE(review): confirm which of the two the backend honours.
            'model_kwargs': {
                'extra_body': {"chat_template_kwargs": {"enable_thinking": False}}
            },
            'chat_template_kwargs': {"enable_thinking": False},
        }
        if base_url:
            # Only override the endpoint when one was actually resolved, so
            # the library default stays in effect otherwise.
            openai.api_base = base_url
            llm_kwargs['openai_api_base'] = base_url

        self.llm = ChatOpenAI(**llm_kwargs)

    def __call__(self, messages: "List[ChatMessage]", stop: Optional[List[str]] = None, replace_newline: bool = True) -> str:
        """Query the model and return its reply text.

        Args:
            messages: Chat messages passed straight to the underlying model.
            stop: Optional stop sequences. ``None`` or an empty list means no
                ``stop`` argument is forwarded to the model.
            replace_newline: When true, every newline is removed from the
                reply before it is returned.

        Returns:
            The reply with surrounding newlines/whitespace stripped and, if
            requested, inner newlines removed.

        Raises:
            Exception: Whatever the final attempt raised once all retries are
                exhausted (rate-limit, connection, timeout or any other
                error).
        """
        kwargs = {'stop': stop} if stop else {}

        max_retries = 3  # total number of attempts
        base_delay = 1   # seconds; grows linearly for transient API errors

        for attempt in range(1, max_retries + 1):
            is_last = attempt == max_retries
            try:
                output = self.llm(messages, **kwargs).content.strip('\n').strip()
            except (openai.error.RateLimitError, openai.error.APIConnectionError, openai.error.Timeout) as e:
                if is_last:
                    print(f'\nFinal retry failed due to {type(e).__name__}')
                    raise
                delay = min(5, base_delay + attempt - 1)  # linear back-off, capped at 5s
                print(f'\nRetrying {attempt}/{max_retries} due to {type(e).__name__}... (waiting {delay}s)')
                time.sleep(delay)
            except Exception as e:
                print(f'\nUnexpected error on attempt {attempt}: {e}')
                if is_last:
                    raise
                time.sleep(2)  # fixed pause before retrying an unexpected error
            else:
                # Printing happens outside the ``try`` so that a failure while
                # echoing the reply cannot trigger another model request.
                print(output)
                break

        # The last attempt either succeeded or re-raised, so ``output`` is
        # always bound here.
        if replace_newline:
            output = output.replace('\n', '')
        return output

def LLM_CLS(llm_name: str, openai_api_key: str, long_ver: bool, openai_base_url: str, timeout: int = 30) -> Callable:
    """Return a callable LLM wrapper for a recognised model name.

    A name is recognised when it contains ``gpt``, ``deepseek`` or ``claude``
    (case-sensitive), or ``qwen`` in any letter case. Every recognised name is
    served by ``GPTWrapper``, which receives all arguments unchanged.

    Raises:
        ValueError: If ``llm_name`` matches none of the known families.
    """
    case_sensitive_families = ('gpt', 'deepseek', 'claude')
    recognised = (
        any(family in llm_name for family in case_sensitive_families)
        or 'qwen' in llm_name.lower()
    )
    if not recognised:
        raise ValueError(f"Unknown LLM model name: {llm_name}")
    return GPTWrapper(llm_name, openai_api_key, long_ver, openai_base_url, timeout)