import re
from langchain_core.messages import BaseMessage, AIMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_core.tools import StructuredTool
from langchain_huggingface import ChatHuggingFace
from langgraph.prebuilt import create_react_agent
from utils.causalgraph_utils import parse_causal_graph, check_nodes_exist
from utils.llm_utils import LLM_Factory
from state import CausalDiscoveryState, DivideState, ConquerState, MergeState
import prompts
from typing import Literal, Optional, Union


class Agent:
    """
    Abstract base class that initializes the LLM and holds the attributes shared by all agents.

    It is meant to be subclassed (ExplainerAgent, DivideAgent, and the ConquerAgent and
    MergeAgent abstract classes). The callable agent itself is stored in the ``react_agent``
    attribute: a langgraph prebuilt ReAct agent built from the configured LLM and tools.

    Methods shared by all agents:
      update_usage_metrics: updates the token counters and the tool-call counter in the state
          the agent is acting on.
      invoke: calls the ReAct agent, updates the usage metrics and the state's message list,
          and appends the last produced message to ``previous_messages``.
    """
    def __init__(
            self,
            tool_list: list[StructuredTool],
            system_prompt_str: str,
            user_prompt_str: str,
            provider: Optional[str] = None,
            name: str = "Agent"
        ):
        """
        Args:
            tool_list: Tools made available to the ReAct agent.
            system_prompt_str: Template for the system prompt, formatted by subclasses.
            user_prompt_str: Template for the user prompt, formatted by subclasses.
            provider: LLM provider identifier forwarded to LLM_Factory.initialize_llm.
            name: Name given to the ReAct agent.
        """
        self.llm = LLM_Factory.initialize_llm(provider)
        self.tool_list = tool_list
        self.react_agent = create_react_agent(
            model=self.llm,
            tools=self.tool_list,
            name=name
        )
        # HuggingFace chat models get special handling for token counting and tool messages
        self.custom_llm = isinstance(self.llm, ChatHuggingFace)
        self.system_prompt_str = system_prompt_str
        self.user_prompt_str = user_prompt_str
        # Placeholders; subclasses fill these in their format_prompts() method
        self.system_prompt = SystemMessage(content="")
        self.user_prompt = HumanMessage(content="")
        # Conversation history sent to the ReAct agent on the next call
        self.previous_messages: list[BaseMessage] = []
        self.round_count = 0

    def update_usage_metrics(self, input: dict, state: dict, output: dict):
        """
        Accumulate token usage and tool-call counts for the messages produced by one ReAct call.

        Only messages appended by the agent are considered, i.e. the entries of
        output["messages"] that come after the first len(input["messages"]) ones.

        Args:
            input: The dict passed to the ReAct agent. Its "messages" list marks where new messages start.
            state: State updated in place. Must provide the "input_token_count" and
                "output_token_count" counters and a "tool_calls" dict mapping tool name -> call count.
            output: The dict returned by the ReAct agent.

        Raises:
            NotImplementedError: If called on the abstract Agent class itself.
        """
        if type(self) is Agent:
            raise NotImplementedError("Agent is an abstract class and should not be instantiated directly.")

        for message in output["messages"][len(input["messages"]):]:
            if self.custom_llm:
                # Count tokens with the custom LLM tokenizer.
                # NOTE(review): the same count is stored as both input and output tokens, and it is
                # only written to the message metadata, not added to the state counters. Confirm intent.
                token_count = self.llm.tokenizer.count_tokens(message.content)
                message.response_metadata["input_tokens"] = token_count
                message.response_metadata["output_tokens"] = token_count
            elif isinstance(message, AIMessage):
                # usage_metadata may be None when the provider does not report usage
                usage = message.usage_metadata or {}
                state["input_token_count"] += usage.get("input_tokens", 0)
                state["output_token_count"] += usage.get("output_tokens", 0)

            if isinstance(message, ToolMessage):
                # Tools may report their own token usage through a dict artifact.
                # This is counted regardless of which LLM drives the agent.
                artifact = message.artifact if isinstance(message.artifact, dict) else {}
                state["input_token_count"] += artifact.get("input_token_count", 0)
                state["output_token_count"] += artifact.get("output_token_count", 0)

                state["tool_calls"][message.name] = state["tool_calls"].get(message.name, 0) + 1

    def invoke(self, input: dict, state: dict, mode: Literal["overwrite", "extend"] = "overwrite") -> dict:
        """
        Invoke the ReAct agent, then update usage counters, state["messages"] and previous_messages.

        Args:
            input: Dict with key "messages" holding a list of BaseMessage objects.
            state: State updated in place (usage counters, tool calls and "messages").
            mode: "overwrite" replaces state["messages"] with the messages from this call.
                "extend" appends them, which is useful when invoking the same agent repeatedly.
                The list written starts at the last input message.

        Returns:
            The dict returned by the ReAct agent. Its "messages" list extends the input messages.

        Raises:
            NotImplementedError: If called on the abstract Agent class itself.
            ValueError: If mode is neither "overwrite" nor "extend".
        """
        if type(self) is Agent:
            raise NotImplementedError("Agent is an abstract class and should not be instantiated directly.")
        if mode not in ("overwrite", "extend"):
            raise ValueError(f"Unknown mode {mode!r}: expected 'overwrite' or 'extend'.")

        if isinstance(self.llm, ChatHuggingFace):
            # Replay ToolMessages as AIMessages tagged with the "function" role for HuggingFace models.
            # NOTE(review): presumably the HF chat model cannot consume ToolMessage directly. Confirm.
            # This mutates input["messages"] in place.
            for k, message in enumerate(input["messages"]):
                if isinstance(message, ToolMessage):
                    input["messages"][k] = AIMessage(message.content)
                    input["messages"][k].role = "function"
        output = self.react_agent.invoke(input=input)

        self.update_usage_metrics(input=input, state=state, output=output)

        # Starts at the last input message (typically the prompt), followed by the new messages
        new_messages = output["messages"][len(input["messages"])-1:]

        if mode == "overwrite":
            # Replace state["messages"], useful for streaming
            state["messages"] = new_messages
        else:
            # Append to state["messages"], useful when invoking the same agent more than once
            state["messages"].extend(new_messages)

        # Remember the agent's final answer in its own history
        self.previous_messages.append(output["messages"][-1])

        return output


class ExplainerAgent(Agent):
    """
    Agent that writes a general description of the dataset.

    The description is stored in state["general_description"] and on the root node of
    state["partition_tree"].
    """
    def __init__(
        self,
        tool_list: list[StructuredTool],
        provider: Optional[str] = None,
        name: str = "ExplainerAgent"
        ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=prompts.explainer_agent_system_prompt,
            user_prompt_str=prompts.explainer_agent_user_prompt,
            provider=provider,
            name=name
            )

    def format_prompts(self, state: CausalDiscoveryState):
        """Build the system/user prompts and reset previous_messages to system prompt + example + user prompt."""
        tool_descriptions = "\n".join([
            f"Name: {tool.name}, Description: {tool.description}" for tool in self.tool_list
        ])

        self.system_prompt = SystemMessage(content=self.system_prompt_str.format(
                domain=state["domain"],
                tools=tool_descriptions
                ))

        root = state["partition_tree"].root
        self.user_prompt = HumanMessage(content=self.user_prompt_str.format(
                dataset_description=root.description,
                variable_names=", ".join(root.variable_names)
                ))

        # Combine the system prompt, the list of messages in the example, and the initial user prompt
        self.previous_messages = [self.system_prompt, *prompts.create_explainer_agent_example(available_tools=[tool.name for tool in self.tool_list]), self.user_prompt]

    def parse_explainer_output(self, explainer_output: str, state: CausalDiscoveryState):
        """
        Extract the <general_description>...</general_description> block from the LLM output and store it.

        The first such block is used. Surrounding whitespace is stripped.

        Raises:
            ValueError: If the output contains no <general_description> block.
        """
        match = re.search(r"<general_description>(.*?)</general_description>", explainer_output, re.DOTALL)
        if match is None:
            raise ValueError("The explainer output does not contain a <general_description>...</general_description> block.")
        general_description = match.group(1).strip()

        # Update the general dataset description in the state and in the root node of the partition tree
        state["general_description"] = general_description
        state["partition_tree"].root.description = general_description

    def go(self, state: CausalDiscoveryState):
        """Run the explainer once and write the parsed general description into the state."""
        self.format_prompts(state)

        input = {"messages": self.previous_messages}
        output = self.invoke(input=input, state=state)

        # Parse the explainer's final answer into the general dataset description
        self.parse_explainer_output(explainer_output=output["messages"][-1].content, state=state)

        return state


class DivideAgent(Agent):
    """
    Base class for agents that split the variables of a partition-tree node into groups.

    Subclasses must implement format_prompts(). go() prompts the LLM and parses the
    <group>/<nodes>/<description> blocks in its answer. It then writes the resulting child
    nodes under state["current_partition_id"] in state["partition_tree"].
    """

    # How many times the LLM is asked to fix an unparsable or invalid partitioning
    MAX_PARSE_RETRIES = 3

    def __init__(
            self,
            tool_list: list[StructuredTool],
            system_prompt_str: str,
            user_prompt_str: str,
            name: str = "DivideAgent",
            provider: Optional[str] = None
        ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=system_prompt_str,
            user_prompt_str=user_prompt_str,
            provider=provider,
            name=name
            )

    def format_prompts(self, state: DivideState):
        # A placeholder method to format prompts before invoking.
        raise NotImplementedError("Subclasses must override this method.")

    @staticmethod
    def _clear_children(state: DivideState, parent_id: str):
        """Remove the current children of parent_id from the partition tree's nodes and edges."""
        tree = state["partition_tree"]
        for child_id in tree.edges.get(parent_id, []):
            if child_id in tree.nodes:
                del tree.nodes[child_id]
        tree.edges[parent_id] = []

    def parse_divide_output(self, divide_output: str, state: DivideState, parent_id: str) -> Union[bool, int]:
        """
        Parse the divide LLM output. If it is valid, replace the children of parent_id with the proposed groups.

        Each group is delimited by <group></group> and contains a comma-separated variable list
        in <nodes></nodes> (square brackets allowed) and a text in <description></description>.
        Child ids are f"{parent_id}.{i}" starting at 1.

        Returns:
            True if parsing succeeded and the children were written. Also True if only one distinct
            group was proposed, in which case the tree is left untouched.
            False if the output is not in the expected format, or if a group has no variables.
            -1 if add_node raised ValueError, e.g. for variables absent from the parent node.
            Any children added during this call are removed again.
        """
        group_pattern = re.compile(r"<group>(.*?)</group>", re.DOTALL)
        node_pattern = re.compile(r"<nodes>(.*?)</nodes>", re.DOTALL)
        description_pattern = re.compile(r"<description>(.*?)</description>", re.DOTALL)

        groups = group_pattern.findall(divide_output)
        if not groups:
            return False

        parsed_groups = []
        for group in groups:
            variable_names = node_pattern.findall(group)
            description = description_pattern.findall(group)

            if not variable_names or not description:
                return False

            # Split on commas, drop brackets and whitespace, and discard empty entries (e.g. trailing commas)
            variable_list = [
                re.sub(r"[\[\]]", "", var_name).strip()
                for var_name in variable_names[0].split(",")
            ]
            variable_list = [var_name for var_name in variable_list if var_name]
            if not variable_list:
                return False

            parsed_groups.append((variable_list, description[0]))

        # Remove duplicate groups based on their set of variable names, keeping the first occurrence
        unique_groups = []
        seen_variable_sets = set()
        for variables, description_text in parsed_groups:
            variable_set = frozenset(variables)
            if variable_set not in seen_variable_sets:
                seen_variable_sets.add(variable_set)
                unique_groups.append((variables, description_text))

        if len(unique_groups) == 1:
            # A single group means "do not divide": leave the partition tree untouched
            return True

        self._clear_children(state, parent_id)

        for i, (variable_list, description_text) in enumerate(unique_groups):
            try:
                state["partition_tree"].add_node(variables=variable_list, parent_id=parent_id, description=description_text, id=f"{parent_id}.{i+1}")
            except ValueError:
                # Roll back so that a rejected partitioning never leaves a partial set of children
                self._clear_children(state, parent_id)
                return -1

        return True

    def go(self, state: DivideState) -> DivideState:
        """
        Divide the partition state["current_partition_id"].

        The LLM is asked to fix its answer up to MAX_PARSE_RETRIES times. The partition id is then
        recorded in state["processed_partition_ids"] whether or not the division succeeded.
        """
        self.format_prompts(state)
        parent_id = state["current_partition_id"]

        output = self.invoke(input={"messages": self.previous_messages}, state=state)

        # If parsing succeeds, the state is updated within parse_divide_output
        parse_flag = self.parse_divide_output(divide_output=output["messages"][-1].content, state=state, parent_id=parent_id)
        retries = 0
        while parse_flag is not True and retries < self.MAX_PARSE_RETRIES:
            # -1: the output contained variables not present in the parent node.
            # False: the formatting was wrong.
            if parse_flag == -1:
                prompt = prompts.groups_wrong_partitioning_fix_prompt
            else:
                prompt = prompts.groups_format_fix_prompt

            input = {"messages": self.previous_messages + [HumanMessage(content=prompt)]}
            output = self.invoke(input=input, state=state, mode="extend")
            retries += 1

            parse_flag = self.parse_divide_output(divide_output=output["messages"][-1].content, state=state, parent_id=parent_id)

        if parse_flag is not True:
            print(f"The divide agent could not produce a valid partitioning of {parent_id} after {retries} retries.")

        # Several agents call go(), so only record the partition once. This lets later steps skip
        # partitions the agents already processed, including ones they chose not to divide.
        if parent_id not in state["processed_partition_ids"]:
            state["processed_partition_ids"].append(parent_id)

        return state


class DivideHypothesisAgent(DivideAgent):
    """Divide agent that proposes an initial grouping of the variables of the current partition."""
    def __init__(
            self,
            tool_list: list[StructuredTool],
            provider: Optional[str] = None,
            name: str = "DivideHypothesisAgent"
        ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=prompts.divide_hypothesis_agent_system_prompt,
            user_prompt_str=prompts.divide_hypothesis_agent_user_prompt,
            provider=provider,
            name=name
            )

    def format_prompts(self, state: DivideState):
        """Build the prompts for dividing state["current_partition_id"] and reset previous_messages."""
        tool_descriptions = "\n".join([
            f"Name: {tool.name}, Description: {tool.description}" for tool in self.tool_list
        ])

        self.system_prompt = SystemMessage(content=self.system_prompt_str.format(
                domain=state["domain"],
                tools=tool_descriptions
                ))

        node_id = state["current_partition_id"]  # ID of the partition to divide

        self.user_prompt = HumanMessage(content=self.user_prompt_str.format(
                dataset_description=state["general_description"],
                variable_names=", ".join(state["partition_tree"].nodes[node_id].variable_names)
                ))

        # Combine the system prompt, the list of messages in the example, and the initial user prompt
        self.previous_messages = [self.system_prompt, *prompts.create_divide_hypothesis_agent_example(available_tools=[tool.name for tool in self.tool_list]), self.user_prompt]


class DivideCriticAgent(DivideAgent):
    """Divide agent that reviews the groups proposed for the current partition and refines them."""
    def __init__(
            self,
            tool_list: list[StructuredTool],
            provider: Optional[str] = None,
            name: str = "DivideCriticAgent"
        ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=prompts.divide_critic_agent_system_prompt,
            user_prompt_str=prompts.divide_critic_agent_user_prompt,
            provider=provider,
            name=name
            )

    def format_prompts(self, state: DivideState):
        """Build prompts showing the current partition and its proposed child groups, and reset previous_messages."""
        tool_descriptions = "\n".join([
            f"Name: {tool.name}, Description: {tool.description}" for tool in self.tool_list
        ])

        self.system_prompt = SystemMessage(content=self.system_prompt_str.format(
                domain=state["domain"],
                tools=tool_descriptions
                ))

        parent_id = state["current_partition_id"]  # ID of the partition being divided
        groups_ids = list(state["partition_tree"].edges.get(parent_id, []))
        groups = [state["partition_tree"].nodes[child_id] for child_id in groups_ids]

        # Generate a string representation of the groups
        groups_str = prompts.format_group_template(groups)

        self.user_prompt = HumanMessage(content=self.user_prompt_str.format(
                dataset_description=state["general_description"],
                variable_names=", ".join(state["partition_tree"].nodes[parent_id].variable_names),
                proposed_groups=groups_str,
        ))

        # Combine the system prompt, the list of messages in the example, and the initial user prompt.
        # NOTE(review): this reuses the *hypothesis* agent example. Confirm a critic-specific example is not intended.
        self.previous_messages = [self.system_prompt, *prompts.create_divide_hypothesis_agent_example(available_tools=[tool.name for tool in self.tool_list]), self.user_prompt]

    def refine(self, state: DivideState):
        """
        Ask the critic to re-partition the current node given state["hsic_results"], then parse the answer.

        The LLM is asked to fix its answer up to 3 times. Afterwards, state["current_partition_id"]
        is removed from state["partition_ids_to_divide"].
        """
        parent_id = state["current_partition_id"]
        hsic_results_str = prompts.format_bad_clusters_template(state["hsic_results"])

        refine_prompt = HumanMessage(content=prompts.divide_critic_agent_refine_prompt.format(bad_clusters=hsic_results_str))

        # Record the refine request in the history, so fix-up retries below show the LLM which
        # request its (invalid) answer was responding to
        self.previous_messages.append(refine_prompt)
        output = self.invoke(input={"messages": list(self.previous_messages)}, state=state)

        # If parsing succeeds, the state is updated within parse_divide_output
        parse_flag = self.parse_divide_output(divide_output=output["messages"][-1].content, state=state, parent_id=parent_id)
        retries = 0
        while parse_flag is not True and retries < 3:
            # -1: the output contained variables not present in the parent node.
            # False: the formatting was wrong.
            if parse_flag == -1:
                prompt = prompts.groups_wrong_partitioning_fix_prompt
            else:
                prompt = prompts.groups_format_fix_prompt

            input = {"messages": self.previous_messages + [HumanMessage(content=prompt)]}
            output = self.invoke(input=input, state=state, mode="extend")
            retries += 1

            parse_flag = self.parse_divide_output(divide_output=output["messages"][-1].content, state=state, parent_id=parent_id)

        if parse_flag is not True:
            print(f"The divide critic could not produce a valid refined partitioning of {parent_id} after {retries} retries.")

        # refine() may be reached for the same partition more than once, so check membership before removing
        if parent_id in state["partition_ids_to_divide"]:
            state["partition_ids_to_divide"].remove(parent_id)

        return state


class ConquerAgent(Agent):
    """
    Abstract class for the agents that perform the conquer operation (HypothesisAgent and CriticAgent).

    It inherits from Agent and handles the interaction with the language model and tools during
    the conquer step of the causal discovery process. Subclasses must implement format_prompts().
    """
    def __init__(
        self,
        tool_list: list[StructuredTool],
        system_prompt_str: str,
        user_prompt_str: str,
        name: str,
        provider: Optional[str] = None
    ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=system_prompt_str,
            user_prompt_str=user_prompt_str,
            provider=provider,
            name=name
        )

    def format_prompts(self, state: ConquerState):
        # A placeholder method to format prompts before invoking.
        raise NotImplementedError("Subclasses must override this method.")

    def parse_check_output(self, output: dict, state: ConquerState, max_retries: int = 3) -> list[tuple[str]]:
        """
        Parse the agent's edge list and check that every node is in state["variable_names"].

        If parsing fails (parse_causal_graph returns -1) or some nodes do not exist, the LLM is
        re-prompted with the matching fix-up prompt, at most max_retries times.

        Args:
            output: The dict returned by the ReAct agent. Its last message holds the edge list.
            state: Conquer state providing "variable_names". Usage counters are updated by retries.
            max_retries: Maximum number of fix-up requests sent to the LLM.

        Returns:
            The parsed list of edges.

        Raises:
            ValueError: If the output is still unparsable or still references unknown
                variables after max_retries fix-up requests.
        """
        variable_names = state["variable_names"]
        parsed_graph = parse_causal_graph(output["messages"][-1].content)
        retries = 0

        while True:
            if parsed_graph == -1:
                # Most likely poor formatting by the LLM: ask it to fix the format
                fix_prompt = prompts.edge_list_format_fix_prompt
                issue = "The output from the agent could not be parsed into a causal graph."
            else:
                nonexisting_nodes = check_nodes_exist(proposed_edges_list=parsed_graph, true_variable_list=variable_names)
                if not nonexisting_nodes:
                    return parsed_graph
                # Hallucinated variables: ask the LLM to review its output
                fix_prompt = prompts.non_estisting_variables_prompt.format(variable_names=variable_names, nonexisting_nodes=nonexisting_nodes)
                issue = "Some nodes in the causal graph are caused by hallucinations."

            if retries >= max_retries:
                raise ValueError(f"{issue} Giving up after {max_retries} retries.")

            print(f"{issue} Trying again...")
            retries += 1
            input = {"messages": self.previous_messages + [HumanMessage(content=fix_prompt)]}
            output = self.invoke(input=input, state=state, mode="extend")
            parsed_graph = parse_causal_graph(output["messages"][-1].content)

    def go(self, state: ConquerState):
        """Run one conquer round: prompt, parse/check the edge list into state["causal_graph"], and bump round_count."""
        if type(self) is ConquerAgent:
            raise NotImplementedError("ConquerAgent is an abstract class and should not be instantiated directly.")

        self.format_prompts(state)

        input = {"messages": self.previous_messages}
        output = self.invoke(input=input, state=state)

        parsed_graph = self.parse_check_output(output=output, state=state)

        state["causal_graph"] = parsed_graph

        # Update round counter
        self.round_count += 1
        return state


class HypothesisAgent(ConquerAgent):
    """Conquer agent that proposes a causal edge list over state["variable_names"]."""

    def __init__(
        self,
        tool_list: list[StructuredTool],
        provider: Optional[str] = None,
        name: str = "HypothesisAgent",
    ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=prompts.conquer_hypothesis_agent_system_prompt,
            user_prompt_str=prompts.conquer_hypothesis_agent_user_prompt,
            provider=provider,
            name=name,
        )

    def format_prompts(self, state: ConquerState):
        """
        Format the system/user prompts and prepare previous_messages for the current round.

        Round 0 resets the history to system prompt + example messages + user prompt.
        Later rounds append a new HumanMessage to the existing history.
        """
        tool_lines = "\n".join(
            f"Name: {t.name}, Description: {t.description}" for t in self.tool_list
        )
        system_text = self.system_prompt_str.format(domain=state["domain"], tools=tool_lines)
        self.system_prompt = SystemMessage(content=system_text)

        user_text = self.user_prompt_str.format(
            general_description=state["general_description"],
            variable_description=state["variable_description"],
            variable_names=", ".join(state["variable_names"]),
        )
        self.user_prompt = HumanMessage(content=user_text)

        if self.round_count == 0:
            tool_names = [t.name for t in self.tool_list]
            example = prompts.create_conquer_hypothesis_agent_example(available_tools=tool_names)
            self.previous_messages = [self.system_prompt, *example, self.user_prompt]
            return

        if self.round_count > 0:
            # NOTE(review): the follow-up prompt for later rounds is empty. Some providers reject
            # empty messages; confirm what this prompt should contain.
            self.previous_messages.append(HumanMessage(content=""))


class CriticAgent(ConquerAgent):
    """Conquer agent that critiques the proposed causal graph and can refine unsupported edges."""

    def __init__(
        self,
        tool_list: list[StructuredTool],
        provider: Optional[str] = None,
        name: str = "CriticAgent",
    ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=prompts.conquer_critic_agent_system_prompt,
            user_prompt_str=prompts.conquer_critic_agent_user_prompt,
            provider=provider,
            name=name,
        )

    def format_prompts(self, state: ConquerState):
        """
        Format the system/user prompts from the state, including state["causal_graph"].

        previous_messages is only reset when round_count == 0.
        NOTE(review): on later rounds the history is left unchanged, so the newly formatted
        user_prompt (with the updated causal graph) is not added to it. Confirm intent.
        """
        tool_lines = "\n".join(
            f"Name: {t.name}, Description: {t.description}" for t in self.tool_list
        )
        system_text = self.system_prompt_str.format(domain=state["domain"], tools=tool_lines)
        self.system_prompt = SystemMessage(content=system_text)

        user_text = self.user_prompt_str.format(
            general_description=state["general_description"],
            variable_description=state["variable_description"],
            variable_names=", ".join(state["variable_names"]),
            causal_graph=state["causal_graph"],
        )
        self.user_prompt = HumanMessage(content=user_text)

        if self.round_count != 0:
            return
        tool_names = [t.name for t in self.tool_list]
        example = prompts.create_conquer_critic_agent_example(available_tools=tool_names)
        self.previous_messages = [self.system_prompt, *example, self.user_prompt]

    def refine(self, state: ConquerState):
        """
        Ask the critic to decide which of state["unsupported_edges"] should be kept.

        These edges were found by data-driven approaches. The resulting graph is
        state["supported_edges"] followed by the edges the critic returns.

        Raises:
            ValueError: If state["unsupported_edges"] is an empty list.
        """
        if state["unsupported_edges"] == []:
            raise ValueError("The list of unsupported edges is empty. The CriticAgent refinement should not be called.")

        self.format_prompts(state)

        refine_text = prompts.critic_refine_user_prompt.format(edges=state["unsupported_edges"])
        messages = [*self.previous_messages, HumanMessage(content=refine_text)]
        output = self.invoke(input={"messages": messages}, state=state)

        kept_edges = self.parse_check_output(output=output, state=state)
        state["causal_graph"] = state["supported_edges"] + kept_edges

        return state

class MergeAgent(Agent):
    """
    Abstract class for the agents that perform the merge operation (MergeHypothesisAgent and MergeCriticAgent).

    It inherits from Agent and handles the interaction with the language model and tools during
    the merge step of the causal discovery process. Subclasses must implement format_prompts().
    """
    def __init__(
        self,
        tool_list: list[StructuredTool],
        system_prompt_str: str,
        user_prompt_str: str,
        name: str,
        provider: Optional[str] = None
    ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=system_prompt_str,
            user_prompt_str=user_prompt_str,
            provider=provider,
            name=name
        )

    def format_prompts(self, state: MergeState):
        # A placeholder method to format prompts before invoking.
        raise NotImplementedError("Subclasses must override this method.")

    def parse_check_output(self, output: dict, state: MergeState, max_retries: int = 3) -> list[list[str]]:
        """
        Parse the agent's edge list and check that every node belongs to one of state["groups"].

        If parsing fails (parse_causal_graph returns -1) or some nodes do not exist, the LLM is
        re-prompted with the matching fix-up prompt, at most max_retries times. Every retry's
        answer is re-checked.

        Args:
            output: The dict returned by the ReAct agent. Its last message holds the edge list.
            state: Merge state providing "groups" (nodes with a variable_names attribute).
            max_retries: Maximum number of fix-up requests sent to the LLM.

        Returns:
            The parsed list of edges.

        Raises:
            ValueError: If the output is still unparsable or still references unknown
                variables after max_retries fix-up requests.
        """
        # All variable names across the groups being merged
        variable_names_in_groups = [var for node in state["groups"] for var in node.variable_names]

        parsed_graph = parse_causal_graph(output["messages"][-1].content)
        retries = 0

        while True:
            if parsed_graph == -1:
                # Most likely poor formatting by the LLM: ask it to fix the format
                fix_prompt = prompts.edge_list_format_fix_prompt
                issue = "The output from the agent could not be parsed into a causal graph."
            else:
                nonexisting_nodes = check_nodes_exist(proposed_edges_list=parsed_graph, true_variable_list=variable_names_in_groups)
                if not nonexisting_nodes:
                    return parsed_graph
                # Hallucinated variables: ask the LLM to review its output
                fix_prompt = prompts.non_estisting_variables_prompt.format(variable_names=variable_names_in_groups, nonexisting_nodes=nonexisting_nodes)
                issue = "Some nodes in the causal graph are caused by hallucinations."

            if retries >= max_retries:
                raise ValueError(f"{issue} Giving up after {max_retries} retries.")

            print(f"{issue} Trying again...")
            retries += 1
            input = {"messages": self.previous_messages + [HumanMessage(content=fix_prompt)]}
            output = self.invoke(input=input, state=state, mode="extend")
            parsed_graph = parse_causal_graph(output["messages"][-1].content)

    def go(self, state: MergeState) -> MergeState:
        """Run one merge round: prompt, parse/check the edge list into state["group_connections"], and bump round_count."""
        if type(self) is MergeAgent:
            raise NotImplementedError("MergeAgent is an abstract class and should not be instantiated directly.")

        self.format_prompts(state)

        input = {"messages": self.previous_messages}
        output = self.invoke(input=input, state=state)

        parsed_graph = self.parse_check_output(output=output, state=state)

        state["group_connections"] = parsed_graph

        # Update round counter
        self.round_count += 1
        return state

class MergeHypothesisAgent(MergeAgent):
    """Merge agent that proposes edges connecting the variables of the groups in state["groups"]."""

    def __init__(
        self,
        tool_list: list[StructuredTool],
        provider: Optional[str] = None,
        name: str = "MergeHypothesisAgent",
    ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=prompts.merging_hypothesis_agent_system_prompt,
            user_prompt_str=prompts.merging_hypothesis_agent_user_prompt,
            provider=provider,
            name=name,
        )

    def format_prompts(self, state: MergeState):
        """Format the prompts from the state and reset previous_messages to system prompt + example + user prompt."""
        tool_lines = "\n".join(
            f"Name: {t.name}, Description: {t.description}" for t in self.tool_list
        )
        system_text = self.system_prompt_str.format(domain=state["domain"], tools=tool_lines)
        self.system_prompt = SystemMessage(content=system_text)

        # Text rendering of the groups to connect
        rendered_groups = prompts.format_group_template(state["groups"])
        user_text = self.user_prompt_str.format(
            general_description=state["general_description"],
            groups=rendered_groups,
        )
        self.user_prompt = HumanMessage(content=user_text)

        tool_names = [t.name for t in self.tool_list]
        example = prompts.create_merging_hypothesis_agent_example(available_tools=tool_names)
        self.previous_messages = [self.system_prompt, *example, self.user_prompt]


class MergeCriticAgent(MergeAgent):
    """Merge agent that reviews the proposed connections in state["group_connections"]."""

    def __init__(
        self,
        tool_list: list[StructuredTool],
        provider: Optional[str] = None,
        name: str = "MergeCriticAgent",
    ):
        super().__init__(
            tool_list=tool_list,
            system_prompt_str=prompts.merging_critic_agent_system_prompt,
            user_prompt_str=prompts.merging_critic_agent_user_prompt,
            provider=provider,
            name=name,
        )

    def format_prompts(self, state: MergeState):
        """
        Format the prompts, including the current group connections rendered as "(a, b)" pairs.

        Resets previous_messages to system prompt + example + user prompt.
        """
        tool_lines = "\n".join(
            f"Name: {t.name}, Description: {t.description}" for t in self.tool_list
        )
        system_text = self.system_prompt_str.format(domain=state["domain"], tools=tool_lines)
        self.system_prompt = SystemMessage(content=system_text)

        # Text rendering of the groups and of the connections proposed so far
        rendered_groups = prompts.format_group_template(state["groups"])
        connection_strs = [f"({a}, {b})" for a, b in state["group_connections"]]
        user_text = self.user_prompt_str.format(
            general_description=state["general_description"],
            groups=rendered_groups,
            group_connections=", ".join(connection_strs),
        )
        self.user_prompt = HumanMessage(content=user_text)

        tool_names = [t.name for t in self.tool_list]
        example = prompts.create_merging_critic_agent_example(available_tools=tool_names)
        self.previous_messages = [self.system_prompt, *example, self.user_prompt]