import json
from typing import List
from langchain import hub
from langchain.agents import create_openai_functions_agent, AgentExecutor, create_tool_calling_agent

from langchain_core.prompts import ChatPromptTemplate

from src.toolfuzz.agent_executors.agent_executor import AgentResponse, TestingAgentExecutor


class OpenAIFuncFactory:
    """Builds an ``AgentExecutor`` around a single tool using ``create_openai_functions_agent``."""

    @classmethod
    def create(cls, tool, model):
        """Create the executor for ``tool`` and smoke-test it with one live invocation.

        Args:
            tool: the LangChain tool exposed to the agent.
            model: the chat model that drives the agent.

        Returns:
            The ``AgentExecutor`` when construction and the warm-up call both succeed,
            otherwise ``None`` (the failure is logged).
        """
        try:
            agent = create_openai_functions_agent(model, [tool], hub.pull("hwchase17/openai-tools-agent"))
            agent = AgentExecutor(agent=agent, tools=[tool], verbose=True, return_intermediate_steps=True)
            # Warm-up call: any error raised while running the agent once is caught below,
            # so a setup that cannot run yields None here instead of failing later.
            agent.invoke({'input': "Can you call the available to you tool with random parameters?"})
            return agent
        except Exception:
            # Still best-effort (None on failure), but no longer swallows KeyboardInterrupt /
            # SystemExit, and the reason for the failure is recorded instead of discarded.
            import logging
            logging.getLogger(__name__).exception(
                "Failed to create OpenAI functions agent for tool %r", getattr(tool, 'name', tool))
            return None


class OpenAIFunction(TestingAgentExecutor):
    """Testing executor that drives a single tool through an OpenAI-functions agent.

    The underlying ``AgentExecutor`` is built by ``OpenAIFuncFactory.create``; when that
    fails ``self.agent`` is ``None`` and ``heartbeat()`` returns ``False``.
    """

    def _import_environment(self):
        # Nothing to import for this executor.
        pass

    def __init__(self, tool, model):
        super().__init__(tool, model)
        # Replace spaces and hyphens with underscores before the agent is built, so the
        # agent registers the tool under the normalised name.
        tool.name = tool.name.replace(' ', '_').replace('-', '_')
        self.agent = OpenAIFuncFactory.create(tool, self.model)

    def __call__(self, prompt, *args, **kwargs):
        """Run the agent on ``prompt`` and summarise the run as an ``AgentResponse``.

        Keyword arguments are forwarded to ``AgentExecutor.invoke``; positional ``*args``
        are accepted but ignored.
        """
        # NOTE(review): self.agent is None when the factory failed, in which case this raises
        # AttributeError; presumably callers check heartbeat() first.
        agent_results = self.agent.invoke({'input': prompt}, **kwargs)

        agent_answer = agent_results['output']
        tool_outputs = self.get_tool_output(agent_results)
        trace = self.get_trace(agent_results)

        # A 'tool' role entry is only produced for a step with an observation.
        is_tool_invoked = any(isinstance(item, dict) and item.get('role') == 'tool' for item in trace)

        raised_exception, error = self.extract_error(trace)

        return AgentResponse(agent_response=agent_answer, trace=trace, is_tool_invoked=is_tool_invoked,
                             is_raised_exception=raised_exception, exception=error, tool_output=tool_outputs,
                             tool_args=self.get_tool_arguments(agent_results))

    def extract_error(self, trace):
        """Return the first error reported in a 'tool' message of ``trace``.

        Two content formats are recognised:

        * a JSON object with an ``'error'`` key whose value is not ``None``;
        * non-JSON text containing ``'Error: '``: the message is the text after ``'Error: '``
          on the first line that contains it.

        Returns:
            ``(True, error)`` for the first error found, otherwise ``(False, '')``.
        """
        for item in trace:
            if not isinstance(item, dict) or item.get('role') != 'tool':
                continue
            content = item['content']
            try:
                parsed = json.loads(content)
            except ValueError:
                for line in content.split('\n'):
                    if 'Error: ' in line:
                        return True, line.split('Error: ')[1]
                continue
            # Only a JSON object can carry an 'error' field; scalars and lists (e.g. "42",
            # "null") are ordinary tool output.
            if isinstance(parsed, dict) and parsed.get('error') is not None:
                return True, parsed['error']
        return False, ''

    @staticmethod
    def _observation_to_text(observation):
        """Return a tool observation as text, JSON-encoding non-string values when possible."""
        if isinstance(observation, str):
            return observation
        try:
            return json.dumps(observation)
        except (TypeError, ValueError):
            return str(observation)

    def get_trace(self, trace) -> list:
        """Convert ``AgentExecutor.invoke`` results into an OpenAI-style chat transcript.

        Reads ``trace['input']``, ``trace['intermediate_steps']`` and ``trace['output']``.
        Each step ``(action, observation)`` contributes a 'system' entry holding the action
        log and an 'assistant' tool-call entry (when ``action`` is an ``AgentAction``),
        followed by a 'tool' entry holding the observation as text.

        The last element is the plain string ``'assistant(<output>, None)'``, not a dict.
        """
        from langchain_core.agents import AgentAction

        transformed_trace = [{'role': 'user', 'content': trace["input"]}]
        for step in trace['intermediate_steps']:
            if not isinstance(step, tuple):
                continue
            action = step[0]
            if isinstance(action, AgentAction):
                transformed_trace.append({'role': 'system', 'content': action.log})
                transformed_trace.append({'role': 'assistant',
                                          'content': None,
                                          "tool_calls": [
                                              {'id': 1,
                                               'type': 'function',
                                               'function': {'name': action.tool,
                                                            'arguments': action.tool_input}}
                                          ]})
            # Non-string observations (e.g. a dict returned by the tool) are converted to
            # text rather than dropped, so the invocation is still visible in the trace.
            transformed_trace.append({'role': 'tool', 'id': 1, 'content': self._observation_to_text(step[1])})
        transformed_trace.append(f'assistant({trace["output"]}, None)')
        return transformed_trace

    def get_tool_output(self, agent_results):
        """Join, most recent first, the observations of every call to ``self.tool``.

        Returns ``"No tool output found"`` when no step called the tool.
        """
        from langchain_core.agents import AgentAction

        tool_outputs = [self._observation_to_text(output)
                        for action, output in reversed(agent_results['intermediate_steps'])
                        if isinstance(action, AgentAction) and action.tool == self.tool.name]
        if tool_outputs:
            return "\n".join(tool_outputs)
        return "No tool output found"

    def get_tool_arguments(self, trace):
        """Return ``tool_input`` of the first ``AgentAction`` step (any tool name), or ``None``."""
        from langchain_core.agents import AgentAction

        for message in trace['intermediate_steps']:
            if isinstance(message, tuple) and isinstance(message[0], AgentAction):
                return message[0].tool_input
        return None

    def heartbeat(self) -> bool:
        """Return ``True`` when the agent was constructed successfully."""
        return self.agent is not None

    def get_name(self) -> str:
        return 'openai_function'


class ToolCallingAgent(TestingAgentExecutor):
    """Testing executor that drives a single tool through ``create_tool_calling_agent``."""

    def _import_environment(self):
        # Nothing to import for this executor.
        pass

    def __init__(self, tool, model):
        super().__init__(tool, model)
        # Normalise the tool name (spaces/hyphens -> underscores) BEFORE building the prompt,
        # so the system message names the tool exactly as it is registered with the agent.
        tool.name = tool.name.replace(' ', '_').replace('-', '_')
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    f"You are a helpful assistant. Make sure to use the {tool.name} tool for information.",
                ),
                ("placeholder", "{chat_history}"),
                ("human", "{input}"),
                ("placeholder", "{agent_scratchpad}"),
            ]
        )
        self.agent = create_tool_calling_agent(self.model, [tool], prompt)
        self.agent = AgentExecutor(agent=self.agent, tools=[tool], verbose=True, return_intermediate_steps=True)

    def __call__(self, prompt, *args, **kwargs):
        """Run the agent on ``prompt`` and summarise the run as an ``AgentResponse``.

        Keyword arguments are forwarded to ``AgentExecutor.invoke``; positional ``*args``
        are accepted but ignored.
        """
        agent_results = self.agent.invoke({'input': prompt}, **kwargs)

        agent_answer = agent_results['output']
        tool_outputs = self.get_tool_output(agent_results)
        trace = self.get_trace(agent_results)

        # A 'tool' role entry is only produced for a step with an observation.
        is_tool_invoked = any(isinstance(item, dict) and item.get('role') == 'tool' for item in trace)

        raised_exception, error = self.extract_error(trace)

        return AgentResponse(agent_response=agent_answer, trace=trace, is_tool_invoked=is_tool_invoked,
                             is_raised_exception=raised_exception, exception=error, tool_output=tool_outputs,
                             tool_args=self.get_tool_arguments(agent_results))

    def extract_error(self, trace):
        """Return the first error reported in a 'tool' message of ``trace``.

        Two content formats are recognised:

        * a JSON object with an ``'error'`` key whose value is not ``None``;
        * non-JSON text containing ``'Error: '``: the message is the text after ``'Error: '``
          on the first line that contains it.

        Returns:
            ``(True, error)`` for the first error found, otherwise ``(False, '')``.
        """
        for item in trace:
            if not isinstance(item, dict) or item.get('role') != 'tool':
                continue
            content = item['content']
            try:
                parsed = json.loads(content)
            except ValueError:
                for line in content.split('\n'):
                    if 'Error: ' in line:
                        return True, line.split('Error: ')[1]
                continue
            # Only a JSON object can carry an 'error' field; scalars and lists (e.g. "42",
            # "null") are ordinary tool output.
            if isinstance(parsed, dict) and parsed.get('error') is not None:
                return True, parsed['error']
        return False, ''

    @staticmethod
    def _observation_to_text(observation):
        """Return a tool observation as text, JSON-encoding non-string values when possible."""
        if isinstance(observation, str):
            return observation
        try:
            return json.dumps(observation)
        except (TypeError, ValueError):
            return str(observation)

    def get_trace(self, trace) -> list:
        """Convert ``AgentExecutor.invoke`` results into an OpenAI-style chat transcript.

        Reads ``trace['input']``, ``trace['intermediate_steps']`` and ``trace['output']``.
        Each step ``(action, observation)`` contributes a 'system' entry holding the action
        log and an 'assistant' tool-call entry (when ``action`` is an ``AgentAction``),
        followed by a 'tool' entry holding the observation as text.

        The last element is the plain string ``'assistant(<output>, None)'``, not a dict.
        """
        from langchain_core.agents import AgentAction

        transformed_trace = [{'role': 'user', 'content': trace["input"]}]
        for step in trace['intermediate_steps']:
            if not isinstance(step, tuple):
                continue
            action = step[0]
            if isinstance(action, AgentAction):
                transformed_trace.append({'role': 'system', 'content': action.log})
                transformed_trace.append({'role': 'assistant',
                                          'content': None,
                                          "tool_calls": [
                                              {'id': 1,
                                               'type': 'function',
                                               'function': {'name': action.tool,
                                                            'arguments': action.tool_input}}
                                          ]})
            # Non-string observations (e.g. a dict returned by the tool) are converted to
            # text rather than dropped, so the invocation is still visible in the trace.
            transformed_trace.append({'role': 'tool', 'id': 1, 'content': self._observation_to_text(step[1])})
        transformed_trace.append(f'assistant({trace["output"]}, None)')
        return transformed_trace

    def get_tool_output(self, agent_results):
        """Join, most recent first, the observations of every call to ``self.tool``.

        Returns ``"No tool output found"`` when no step called the tool.
        """
        from langchain_core.agents import AgentAction

        tool_outputs = [self._observation_to_text(output)
                        for action, output in reversed(agent_results['intermediate_steps'])
                        if isinstance(action, AgentAction) and action.tool == self.tool.name]
        if tool_outputs:
            return "\n".join(tool_outputs)
        return "No tool output found"

    def get_tool_arguments(self, trace):
        """Return ``tool_input`` of the first ``AgentAction`` step (any tool name), or ``None``."""
        from langchain_core.agents import AgentAction

        for message in trace['intermediate_steps']:
            if isinstance(message, tuple) and isinstance(message[0], AgentAction):
                return message[0].tool_input
        return None

    def heartbeat(self) -> bool:
        """Return ``True`` when an agent executor is present."""
        return self.agent is not None

    def get_name(self) -> str:
        return 'tool_calling_agent'

