import time
from typing import Dict, List, Optional, Union, Callable, Literal, Optional, Union
import logging
import openai
import json
from openai import OpenAI
from autogen.agentchat import Agent, UserProxyAgent, ConversableAgent
from termcolor import colored
import Levenshtein

logger = logging.getLogger(__name__)


class GuardAgent(UserProxyAgent):
    def __init__(
            self,
            name: str,
            is_termination_msg: Optional[Callable[[Dict], bool]] = None,
            max_consecutive_auto_reply: Optional[int] = None,
            human_input_mode: Optional[str] = "ALWAYS",
            function_map: Optional[Dict[str, Callable]] = None,
            code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
            default_auto_reply: Optional[Union[str, Dict, None]] = "",
            llm_config: Optional[Union[Dict, Literal[False]]] = False,
            system_message: Optional[Union[str, List]] = "",
            config_list: Optional[List[Dict]] = None,
    ):
        super().__init__(
            name=name,
            system_message=system_message,
            is_termination_msg=is_termination_msg,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            function_map=function_map,
            code_execution_config=code_execution_config,
            llm_config=llm_config,
            default_auto_reply=default_auto_reply,
        )
        self.config_list = config_list
        self.user_request = ''
        self.agent_specification = ''
        self.agent_input = ''
        self.agent_output = ''
        self.subtasks = ''
        self.code = ''

    def task_decomposition(self, config, user_request, agent_specification, agent_input, agent_output, Decomposition_Examples):
        # import prompt
        from prompts_guard import Example_Decomposition
        # Returns the related information to the given query.
        patience = 2
        sleep_time = 30
        openai.api_key = config["api_key"]
        engine = config["model"]
        query_message = Example_Decomposition.format(user_request=user_request,
                                                     agent_specification=agent_specification,
                                                     decomposition_examples=Decomposition_Examples,
                                                     agent_input=agent_input,
                                                     agent_output=agent_output)
        from prompts_guard import SYSTEM_PROMPT_DECOMPOSITION
        messages = [{"role": "system", "content": SYSTEM_PROMPT_DECOMPOSITION},
                    {"role": "user", "content": query_message}]
        # 创建 OpenAI 客户端，支持 base_url（用于本地代理或自定义 API）
        client_kwargs = {"api_key": config["api_key"]}
        if "base_url" in config:
            client_kwargs["base_url"] = config["base_url"]
        client = OpenAI(**client_kwargs)
        while patience > 0:
            patience -= 1
            try:
                response = client.chat.completions.create(
                    model=engine,
                    messages=messages,
                    temperature=0,
                    max_tokens=3000,  # 增加 max_tokens 避免截断
                    top_p=0.95,
                    frequency_penalty=0,
                    presence_penalty=0,
                    stop=None)
                prediction = response.choices[0].message.content.strip()
                if prediction != "" and prediction != None:
                    return prediction
            except Exception as e:
                error_msg = str(e)
                error_type = type(e).__name__
                print(f"Error in task_decomposition: {error_type}: {error_msg}")
                # 如果是连接错误，提供更详细的诊断信息
                if "Connection" in error_type or "connection" in error_msg.lower():
                    base_url = config.get("base_url", "not set")
                    print(f"  Attempting to connect to: {base_url}")
                    print(f"  Model: {engine}")
                    print(f"  API Key: {'set' if config.get('api_key') else 'not set'}")
                if sleep_time > 0:
                    time.sleep(sleep_time)
        return "Fail to retrieve related knowledge, please try again later."

    def retrieve_examples(self, agent_input, agent_output):
        levenshtein_dist = {}
        for i in range(len(self.memory)):
            mem_input = self.memory[i]["agent input"]
            mem_output = self.memory[i]["agent output"]
            levenshtein_dist[i] = Levenshtein.distance(agent_input, mem_input) + Levenshtein.distance(agent_output, mem_output)
        levenshtein_dist = sorted(levenshtein_dist.items(), key=lambda x: x[1], reverse=False)
        selected_indexes = [levenshtein_dist[i][0] for i in range(min(self.num_shots, len(levenshtein_dist)))]
        examples = []
        for i in selected_indexes:
            template = "Agent input:\n {}\nAgent output:\n{}\nTask decomposition:\n{}\nGuardrail code:\n{}\n".format(self.memory[i]["agent input"],
                                                                                                               self.memory[i]["agent output"],
                                                                                                               self.memory[i]["subtasks"],
                                                                                                               self.memory[i]["code"])
            examples.append(template)
        examples = '\n'.join(examples)
        return examples

    def generate_init_message(self, **context):

        self.user_request = context["user_request"]
        self.agent_specification = context["agent_specification"]
        self.agent_input = context["agent_input"]
        self.agent_output = context["agent_output"]
        self.agent_task_deco_examples = context["agent_task_deco_examples"]

        # import prompt
        from prompts_guard import GuardAgent_Message_Prompt
        subtasks = self.task_decomposition(self.config_list[0],
                                           self.user_request,
                                           self.agent_specification,
                                           self.agent_input,
                                           self.agent_output,
                                           self.agent_task_deco_examples)
        self.subtasks = subtasks

        examples = self.retrieve_examples(self.agent_input, self.agent_output)

        init_message = GuardAgent_Message_Prompt.format(examples=examples,
                                                        agent_input=self.agent_input,
                                                        agent_output=self.agent_output,
                                                        subtasks=subtasks)
        return init_message

    def send(self, message: Union[Dict, str], recipient: Agent, request_reply: Optional[bool] = None,
             silent: Optional[bool] = False):
        valid = self._append_oai_message(message, "assistant", recipient)
        if valid:
            recipient.receive(message, self, request_reply, silent)
        else:
            raise ValueError(
                "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
            )

    def initiate_chat(self, recipient: "ConversableAgent", clear_history: Optional[bool] = True,
                      silent: Optional[bool] = False, **context, ):
        self._prepare_chat(recipient, clear_history)
        self.send(self.generate_init_message(**context), recipient, silent=silent)

    def receive(
            self,
            message: Union[Dict, str],
            sender: Agent,
            request_reply: Optional[bool] = None,
            silent: Optional[bool] = False,
    ):
        self._process_received_message(message, sender, silent)
        if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False:
            return
        reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
        if reply is not None:
            self.send(reply, sender, silent=silent)

    def error_debugger(self, config, code, error_info):
        # import prompt
        from prompts_guard import CodeDebugger
        # Returns the related information to the given query.
        patience = 1
        sleep_time = 30
        openai.api_key = config["api_key"]
        engine = config["model"]
        query_message = CodeDebugger.format(subtasks=self.subtasks, code=code, error_info=error_info)
        messages = [{"role": "system",
                     "content": "You are an AI assistant that helps people debug their code. Only list one most possible reason to the errors."},
                    {"role": "user", "content": query_message}]
        # 创建 OpenAI 客户端，支持 base_url（用于本地代理或自定义 API）
        client_kwargs = {"api_key": config["api_key"]}
        if "base_url" in config:
            client_kwargs["base_url"] = config["base_url"]
        client = OpenAI(**client_kwargs)
        while patience > 0:
            patience -= 1
            try:
                response = client.chat.completions.create(
                    model=engine,
                    messages=messages,
                    temperature=0,
                    max_tokens=1000,
                    top_p=0.95,
                    frequency_penalty=0,
                    presence_penalty=0,
                    stop=None)
                prediction = response.choices[0].message.content.strip()
                if prediction != "" and prediction != None:
                    return prediction
            except Exception as e:
                print(f"Error in error_debugger: {e}")
                if sleep_time > 0:
                    time.sleep(sleep_time)
        return "Fail to diagnose the reasons to the errors."

    def execute_function(self, func_call, call_id=None):
        """Execute a function call and return the result.

        Override this function to modify the way to execute a function call.

        Args:
            func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
            call_id: Optional call ID for tracking function calls (added for compatibility with newer autogen versions).

        Returns:
            A tuple of (is_exec_success, result_dict).
            is_exec_success (boolean): whether the execution is successful.
            result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
        """
        func_name = func_call.get("name", "")
        func = self._function_map.get(func_name, None)

        is_exec_success = False
        if func is not None:
            # Extract arguments from a json-like string and put it into a dict.
            input_string = self._format_json_str(func_call.get("arguments", "{}"))
            try:
                arguments = json.loads(input_string)
            except json.JSONDecodeError as e:
                arguments = None
                if func_call["arguments"].find('\"\"\"') > 0:
                    arguments_string = func_call["arguments"].split('\"\"\"')[1]
                else:
                    arguments_string = func_call["arguments"].split(': "')[-1]
                    arguments_string = arguments_string.split('", ')[0]
                arguments = {"cell": arguments_string}
                # content = f"Error: {e}\n You argument should follow json format."
                content = f"Error: {e}\n There might be compilation errors in the code. Please check the code and try again."

            # Try to execute the function
            if arguments is not None:
                print(
                    colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"),
                    flush=True,
                )
                self.code = arguments["cell"]
                try:
                    content = func(**arguments)
                    is_exec_success = True
                except Exception as e:
                    content = f"Error: {e}"
        else:
            content = f"Error: Function {func_name} not found."
        if "error" in content or "Error" in content:
            # 使用执行用的 config（config_list[1]），如果只有一个 config 则使用 config_list[0]
            execution_config = self.config_list[1] if len(self.config_list) > 1 else self.config_list[0]
            reasons = self.error_debugger(execution_config, self.code, content)
            content = content + '\nPotential Reasons: ' + reasons

        return is_exec_success, {
            "name": func_name,
            "role": "function",
            "content": str(content),
        }

    def update_memory(self, num_shots, memory):
        self.num_shots = num_shots
        self.memory = memory


