from state import MergeState, CausalDiscoveryState
from agents import MergeHypothesisAgent, MergeCriticAgent
from utils.tool_utils import init_tool_list
from utils.causalgraph_utils import merge_causallearn_graphs
from langgraph.graph import StateGraph, START, END
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode

def merge(state: CausalDiscoveryState, parent_node_id: str, config: dict = None) -> CausalDiscoveryState:
    """
    Function that creates a LangGraph subgraph composed of an Hypothesis and a Critic agent.
    It is used to merge the results of the conquer() function on solved PartitionTree nodes.

    These are group of variables for which the agents in the conquer() function have already found causal edges.

    We pass the groups of variables we want to merge together along with the prompts.
    The agents' job will be to connect the groups to form a larger graph. 
    This will be saved in the edge_list attribute of the parent PartitionNode object within the PartitionTree.
    This function will be called iteratively until the whole partition tree is solved.
    """

    # Instantiate the MergeHypothesisAgent and MergeCriticAgent with the provided tools
    tool_list = init_tool_list(config["tool_list"])
    merge_hypothesis_agent_instance = MergeHypothesisAgent(tool_list=tool_list)
    merge_critic_agent_instance = MergeCriticAgent(tool_list=tool_list)

    # Build the subgraph
    merge_subgraph_builder = StateGraph(state_schema=MergeState)
    merge_subgraph_builder.add_edge(START, "Merge Hypothesis")
    merge_subgraph_builder.add_node("Merge Hypothesis", merge_hypothesis_agent_instance.go)
    merge_subgraph_builder.add_node("Merge Critic", merge_critic_agent_instance.go)
    merge_subgraph_builder.add_edge("Merge Hypothesis", "Merge Critic")
    merge_subgraph_builder.add_edge("Merge Critic", END)
    merge_subgraph = merge_subgraph_builder.compile()

    # Prepare input state
    # First, gather the children of the parent, which are the nodes to be merged
    # We double check if they are all solved, and if not, raise an error
    children_ids = state["partition_tree"].edges[parent_node_id]
    children_to_merge = []
    parent_edge_list = [] # This will be either a list[list[str]] or a list[GeneralGraph], depending on wether fci_conquer is True or False
    for id in children_ids:
        child = state["partition_tree"].nodes[id]
        if child.solved != True:
            raise ValueError(f"Cannot merge unsolved node: {id}")

        else:
            children_to_merge.append(child)
            # We need to handle differently wether child.edge_list is a list[list[str]] or a GeneralGraph object
            if isinstance(child.edge_list, list):
                parent_edge_list.extend(child.edge_list)
            elif isinstance(child.edge_list, GeneralGraph):
                print(child.edge_list.get_graph_edges())
                parent_edge_list.append(child.edge_list)
            else:
                raise TypeError(f"Unsupported type for child.edge_list: {type(child.edge_list)}")

    input_state = MergeState(messages=state["messages"],
                                             groups=children_to_merge,
                                             group_connections=[],
                                             domain=state["domain"],
                                             general_description=state["general_description"],
                                             input_token_count=0,
                                             output_token_count=0,
                                             tool_calls={})

    # Run the subgraph with the input state
    output = merge_subgraph.invoke(input=input_state)

    # Here we handle the merging differently depending on the fci_conquer flag
    # If fci_conquer is True, we extend the list[list[str]] of the parent node
    if not config["fci_conquer"]:
        parent_edge_list += output["group_connections"]

        # Remove duplicates from the parent_edge_list
        parent_edge_list = list(set(parent_edge_list))

    elif config["fci_conquer"]:
        # If fci_conquer is True, we have to be careful: parent_edge_list is a list of Graph objects, 
        # while output["group_connections"] is a list[list[str]]

        # We need to merge the list into a single graph
        parent_edge_list = merge_causallearn_graphs(parent_edge_list)
        # Now parent_edge_list is a single GeneralGraph object

        # Then we need to create a new graph from the output["group_connections"]
        for edge in output["group_connections"]:
            node1 = GraphNode(name=edge[0])
            node2 = GraphNode(name=edge[1])
            parent_edge_list.add_node(node1)
            parent_edge_list.add_node(node2)
            parent_edge_list.add_directed_edge(node1=node1, node2=node2)

    # Get the output state and update the parent node with the new edge list
    state["partition_tree"].set_solved(parent_node_id, edge_list=parent_edge_list)

    # Update token counters in the state
    state["input_token_count"] += output["input_token_count"]
    state["output_token_count"] += output["output_token_count"]

    # Update the call counter for the tools used in this task
    for tool, count in output["tool_calls"].items():
        state["tool_calls"][tool] = state["tool_calls"].get(tool, 0) + count

    return state