from openai import OpenAI
import time
import json
import os
import sys

sys.path.append("./model_api")
from BaseAPI import BaseAPI


class MistralLocalAPI(BaseAPI):
    def __init__(self, model_name, generation_config=None):
        """Create a client for a locally served, OpenAI-compatible endpoint.

        Args:
            model_name: Model identifier sent with every completion request.
            generation_config: Optional dict of generation settings forwarded
                to ``BaseAPI.__init__``. ``None`` (the default) is treated as
                an empty dict.
        """
        # Build a fresh dict per instance: a ``{}`` default in the signature is
        # a single object shared by every instance created without a config.
        if generation_config is None:
            generation_config = {}
        super().__init__(generation_config)
        self.model_name = model_name
        # Only the port is configurable, through the MISTRAL_PORT environment
        # variable (default 7001). The host (localhost) and API key are fixed.
        port = os.environ.get("MISTRAL_PORT", 7001)
        base_url = f"http://localhost:{port}/v1"
        api_key = "forward"
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )
        # NOTE(review): this attribute is not defined in this class; presumably
        # it is provided by BaseAPI -- confirm.
        self.sys_prompt = self.without_strict_jsonformat_sys_prompt

    def response(self, messages, tools):
        if not tools:
            tools = None
        
        # 验证和清理 messages 格式
        cleaned_messages = self._clean_messages(messages)
        
        # 合并生成配置，添加防死循环参数
        config = self.generation_config.copy()
        # 确保 max_tokens 有上限，防止死循环
        if 'max_tokens' not in config:
            config['max_tokens'] = 2048
        else:
            config['max_tokens'] = min(config.get('max_tokens', 2048), 4096)
        
        # 移除可能不支持的参数（某些 OpenAI 兼容 API 不支持这些参数）
        # repetition_penalty 不是 OpenAI API 的标准参数，某些实现不支持
        unsupported_params = ['repetition_penalty']
        for param in unsupported_params:
            if param in config:
                config.pop(param)
        
        for attempt in range(10):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    tools=tools,
                    messages=cleaned_messages,
                    **config
                )
                if completion is None or completion.choices is None:
                    continue
                return completion
            except Exception as e:
                error_str = str(e).lower()
                # 检查是否是 JSON 解析错误
                if 'expecting' in error_str and 'delimiter' in error_str:
                    print(f"Warning: JSON parsing error detected, attempting to fix messages format...")
                    # 尝试进一步清理 messages
                    cleaned_messages = self._clean_messages(messages, strict=True)
                    if attempt < 9:
                        time.sleep(1)
                        continue
                
                # 检查是否是参数不支持的错误
                if 'unexpected keyword argument' in error_str or 'invalid' in error_str or 'not supported' in error_str:
                    # 尝试识别不支持的参数
                    if 'repetition_penalty' in error_str:
                        if 'repetition_penalty' in config:
                            config.pop('repetition_penalty')
                            print(f"Warning: repetition_penalty not supported, removed from config")
                            continue  # 重试
                    # 其他不支持的参数
                    for param in list(config.keys()):
                        if param in error_str and param not in ['model', 'tools', 'messages', 'temperature', 'max_tokens', 'top_p', 'top_k']:
                            print(f"Warning: Parameter '{param}' may not be supported, removing...")
                            config.pop(param, None)
                            continue  # 重试
                
                # 检查是否是 content_filter 错误（Azure OpenAI / 360 API 的内容过滤）
                if 'content_filter' in error_str or 'content_management' in error_str or 'responsibleaipolicyviolation' in error_str or 'jailbreak' in error_str:
                    print(f"⚠️ Content filter triggered: {str(e)[:200]}...")
                    return "CONTENT_FILTERED"
                
                if attempt < 9:  # 不是最后一次尝试
                    print(f"Error in MistralLocalAPI.response (attempt {attempt + 1}/10): {e}")
                    time.sleep(1)
                else:
                    print(f"Failed after 10 attempts: {e}")
                    raise
    
    def _clean_messages(self, messages, strict=False):
        """
        清理和验证 messages 格式，确保符合 OpenAI API 规范
        
        Args:
            messages: 原始消息列表
            strict: 是否使用严格模式（更彻底的清理）
        
        Returns:
            清理后的消息列表
        """
        cleaned = []
        for msg in messages:
            cleaned_msg = {}
            
            # 确保 role 字段存在且有效
            role = msg.get("role")
            if not role or role not in ["system", "user", "assistant", "tool"]:
                # 跳过无效的 role
                continue
            
            cleaned_msg["role"] = role
            
            # 处理 content 字段
            content = msg.get("content")
            if content is None:
                # 对于 assistant 消息，如果没有 content，设置为 None（如果有 tool_calls）
                if role == "assistant" and msg.get("tool_calls"):
                    cleaned_msg["content"] = None
                else:
                    cleaned_msg["content"] = ""
            else:
                # 确保 content 是字符串
                if isinstance(content, (dict, list)):
                    # 如果是字典或列表，转换为 JSON 字符串
                    try:
                        cleaned_msg["content"] = json.dumps(content, ensure_ascii=False)
                    except:
                        cleaned_msg["content"] = str(content)
                else:
                    cleaned_msg["content"] = str(content)
            
            # 处理 tool_calls 字段（仅 assistant 消息）
            if role == "assistant" and "tool_calls" in msg:
                tool_calls = msg["tool_calls"]
                if tool_calls and len(tool_calls) > 0:
                    cleaned_tool_calls = []
                    for tc in tool_calls:
                        cleaned_tc = {}
                        if "id" in tc:
                            cleaned_tc["id"] = tc["id"]
                        if "type" in tc:
                            cleaned_tc["type"] = tc["type"]
                        if "function" in tc:
                            func = tc["function"]
                            cleaned_func = {}
                            if "name" in func:
                                cleaned_func["name"] = func["name"]
                            
                            # 处理 arguments：确保是有效的 JSON 字符串
                            if "arguments" in func:
                                args = func["arguments"]
                                if isinstance(args, dict):
                                    # 如果已经是字典，转换为 JSON 字符串
                                    cleaned_func["arguments"] = json.dumps(args, ensure_ascii=False)
                                elif isinstance(args, str):
                                    # 如果是字符串，验证是否为有效 JSON
                                    try:
                                        json.loads(args)  # 验证
                                        cleaned_func["arguments"] = args
                                    except json.JSONDecodeError:
                                        # 如果无效，尝试修复
                                        try:
                                            # 移除尾随逗号
                                            fixed_args = args.rstrip().rstrip(',')
                                            json.loads(fixed_args)  # 再次验证
                                            cleaned_func["arguments"] = fixed_args
                                        except:
                                            # 如果修复失败，使用空字典
                                            cleaned_func["arguments"] = "{}"
                                else:
                                    cleaned_func["arguments"] = "{}"
                            else:
                                cleaned_func["arguments"] = "{}"
                            
                            cleaned_tc["function"] = cleaned_func
                        cleaned_tool_calls.append(cleaned_tc)
                    
                    cleaned_msg["tool_calls"] = cleaned_tool_calls
                    # 如果有 tool_calls，确保 content 为 None
                    cleaned_msg["content"] = None
            
            # 处理 tool_call_id（仅 tool 消息）
            if role == "tool" and "tool_call_id" in msg:
                cleaned_msg["tool_call_id"] = msg["tool_call_id"]
            
            cleaned.append(cleaned_msg)
        
        return cleaned

    def generate_response(self, messages, tools):
        completion = self.response(messages, tools)

        if completion is None:
            return None
        
        # 处理 content_filter 错误：返回一个标准的拒绝响应（符合 OpenAI assistant 消息格式）
        if completion == "CONTENT_FILTERED":
            refuse_response = (
                "I'm sorry, but I cannot assist with this request. "
                "The content has been flagged by the safety system as potentially harmful or inappropriate. "
                "Please modify your request to comply with safety guidelines."
            )
            return {"type": "content", "content": refuse_response}

        ## tool call part
        if completion.choices[0].message.tool_calls is not None and len(completion.choices[0].message.tool_calls) > 0:
            tool_call = completion.choices[0].message.tool_calls[0]
            tool_call_id = tool_call.id
            tool_name = tool_call.function.name
            arguments = {}
            if tool_call.function.arguments:
                try:
                    # 尝试解析 JSON
                    if isinstance(tool_call.function.arguments, str):
                        arguments = json.loads(tool_call.function.arguments)
                    elif isinstance(tool_call.function.arguments, dict):
                        # 如果已经是字典，直接使用
                        arguments = tool_call.function.arguments
                    else:
                        arguments = {}
                except json.JSONDecodeError as e:
                    # JSON 解析失败，尝试修复常见问题
                    print(f"Warning: Failed to parse tool arguments JSON: {e}")
                    print(f"  Raw arguments: {tool_call.function.arguments}")
                    # 尝试修复：移除可能的尾随逗号、修复引号等
                    try:
                        args_str = str(tool_call.function.arguments).strip()
                        # 移除尾随逗号
                        if args_str.endswith(','):
                            args_str = args_str[:-1].rstrip()
                        # 如果不是以 { 开头，尝试添加
                        if not args_str.startswith('{'):
                            args_str = '{' + args_str + '}'
                        arguments = json.loads(args_str)
                    except:
                        # 如果修复失败，使用空字典
                        arguments = {}
            return {
                "type": "tool",
                "tool_call_id": tool_call_id,
                "tool_name": tool_name,
                "arguments": arguments,
            }

        ## normal content part
        else:
            content = completion.choices[0].message.content
            # 检测并处理死循环：如果内容重复超过3次，截断
            if content:
                content = self._detect_and_fix_loop(content)
            return {"type": "content", "content": content}
    
    def _detect_and_fix_loop(self, content, max_repeat=3):
        """
        检测内容中的死循环模式并截断
        
        Args:
            content: 原始内容
            max_repeat: 允许的最大重复次数，超过则截断
        
        Returns:
            处理后的内容
        """
        if not content or len(content) < 20:
            return content
        
        # 检测短句重复（检查最后100个字符是否重复）
        check_length = min(100, len(content))
        last_part = content[-check_length:]
        
        # 尝试找到重复模式（检查2-20字符的重复）
        for pattern_len in range(2, min(20, len(last_part) // 2)):
            pattern = last_part[-pattern_len:]
            # 检查这个模式是否重复出现
            repeat_count = 0
            check_start = len(content) - check_length
            for i in range(check_start, len(content) - pattern_len, pattern_len):
                if content[i:i+pattern_len] == pattern:
                    repeat_count += 1
                else:
                    break
            
            # 如果重复超过阈值，截断到重复开始之前
            if repeat_count >= max_repeat:
                # 找到第一个重复的位置
                first_repeat_pos = content.rfind(pattern, 0, len(content) - pattern_len * max_repeat)
                if first_repeat_pos > 0:
                    # 截断到第一个重复之前，并添加提示
                    truncated = content[:first_repeat_pos].rstrip()
                    if truncated:
                        return truncated + "\n[内容已截断，检测到重复模式]"
                    else:
                        return content[:len(content) - pattern_len * max_repeat].rstrip() + "\n[内容已截断，检测到重复模式]"
        
        return content
