from langgraph.graph import StateGraph, START, END

from utils.tool_utils import init_tool_list
from agents import DivideHypothesisAgent, DivideCriticAgent
from state import CausalDiscoveryState, DivideState
from utils.hsic import evaluate_clusters_hsic
from langchain_core.runnables import RunnableConfig

def divide_loop(state: CausalDiscoveryState, config: dict) -> CausalDiscoveryState:
    """Run the Divide sub-graph and merge its results back into ``state``.

    A ``StateGraph`` over ``DivideState`` is assembled in which the Divide
    Hypothesis and Divide Critic agents work on one partition at a time,
    starting from the root partition ``"0"``. After each round the
    "Check Partitions Size" node selects the next partition reported by
    ``get_partition_ids_with_more_than_k_nodes`` that is not yet listed in
    ``processed_partition_ids``; the sub-graph ends when none is left.

    When ``config["data_driven_divide"]`` is truthy, the critic is followed by
    the ``hsic_test`` node, whose router (``hsic_test_outcome``) decides
    whether the critic's ``refine`` step runs before the size check.

    Args:
        state: Outer pipeline state. ``messages``, ``partition_tree``,
            ``domain``, ``general_description`` and ``dataset`` are read;
            ``partition_tree``, ``input_token_count``, ``output_token_count``
            and ``tool_calls`` are updated in place.
        config: Mapping providing ``"tool_list"``, ``"k_threshold"`` and
            ``"data_driven_divide"``.

    Returns:
        The same ``state`` object, updated with the sub-graph's output.
    """
    # List of StructuredTool objects, shared by both agents
    tool_list = init_tool_list(config["tool_list"])

    divide_hypothesis_instance = DivideHypothesisAgent(tool_list=tool_list)
    divide_critic_instance = DivideCriticAgent(tool_list=tool_list)

    def check_partitions_size(state: DivideState) -> DivideState:
        """Select the next partition to divide, or ``None`` when there is none."""
        oversized_ids = state["partition_tree"].get_partition_ids_with_more_than_k_nodes(k=config["k_threshold"])

        # Skip partitions that have already been processed so that we never loop on the same one
        processed_ids = state["processed_partition_ids"]
        pending_ids = [p for p in oversized_ids if p not in processed_ids]

        state["partition_ids_to_divide"] = pending_ids
        # The first pending partition is handled next; None tells the router to stop
        state["current_partition_id"] = pending_ids[0] if pending_ids else None

        return state

    def make_further_subpartitions(state: DivideState):
        """Route to another Divide round while a partition is selected, else end."""
        # Compare against None explicitly: a valid but falsy ID (e.g. 0 or "")
        # must not be mistaken for "nothing left to divide".
        if state["current_partition_id"] is not None:
            return "Divide Hypothesis"
        return END

    divide_graph_builder = StateGraph(DivideState)
    divide_graph_builder.add_node("Divide Hypothesis", divide_hypothesis_instance.go)
    divide_graph_builder.add_node("Divide Critic", divide_critic_instance.go)
    divide_graph_builder.add_node("Check Partitions Size", check_partitions_size)
    divide_graph_builder.add_edge(START, "Divide Hypothesis")
    divide_graph_builder.add_edge("Divide Hypothesis", "Divide Critic")
    divide_graph_builder.add_conditional_edges("Check Partitions Size", make_further_subpartitions)

    if not config["data_driven_divide"]:
        # Graph without data-driven step: the critic feeds the size check directly
        divide_graph_builder.add_edge("Divide Critic", "Check Partitions Size")
    else:
        # Graph with data-driven step: critic -> HSIC test -> (optional) refinement -> size check
        divide_graph_builder.add_node("HSIC_test", hsic_test)
        divide_graph_builder.add_node("Divide Critic - Data-driven Refinement", divide_critic_instance.refine)

        divide_graph_builder.add_edge("Divide Critic", "HSIC_test")
        divide_graph_builder.add_conditional_edges("HSIC_test", hsic_test_outcome)
        divide_graph_builder.add_edge("Divide Critic - Data-driven Refinement", "Check Partitions Size")

    subgraph = divide_graph_builder.compile()

    input_state = DivideState(messages=state["messages"],
                              partition_tree=state["partition_tree"],
                              domain=state["domain"],
                              general_description=state["general_description"],
                              partition_ids_to_divide=["0"], # 0 is the ID of the root node of the PartitionTree object: the divide agents will start from that
                              current_partition_id="0",
                              processed_partition_ids=[],
                              hsic_results={},
                              input_token_count=0,
                              output_token_count=0,
                              tool_calls={},
                              dataset=state["dataset"])

    # The sub-graph loops once per partition, so the recursion limit is raised explicitly to 100
    output = subgraph.invoke(input=input_state, config = RunnableConfig(recursion_limit=100))

    # Update the PartitionTree with the output from the subgraph
    state["partition_tree"] = output["partition_tree"]

    # Update token counters in the state
    state["input_token_count"] += output["input_token_count"]
    state["output_token_count"] += output["output_token_count"]

    # Update the call counter for the tools used in this task
    for tool, count in output["tool_calls"].items():
        state["tool_calls"][tool] = state["tool_calls"].get(tool, 0) + count

    return state

def hsic_test(state: DivideState) -> DivideState:
    """Evaluate the children of the current partition with the HSIC helper.

    For every child of ``state["current_partition_id"]`` in the partition
    tree, a slice of ``state["dataset"]`` restricted to that child's
    variables is built. The resulting ``{child_id: DataFrame}`` mapping is
    passed to ``evaluate_clusters_hsic`` and whatever it returns is stored in
    ``state["hsic_results"]``.

    Args:
        state: Divide sub-graph state; ``partition_tree``,
            ``current_partition_id`` and ``dataset`` are read.

    Returns:
        The same ``state`` object with ``hsic_results`` overwritten.

    Raises:
        ValueError: If selecting a child's variables from the dataset raises
            ``KeyError``; the original ``KeyError`` is chained as the cause.
    """
    partition_tree = state["partition_tree"]

    # The function to evaluate clusters wants a dict with PartitionNode ID as keys and DataFrame slices as values
    clusters_dict = {}
    for child_id in partition_tree.edges[state["current_partition_id"]]:
        variables_in_group = partition_tree.nodes[child_id].variable_names
        try:
            # Only the first 1000 rows are kept to avoid memory issues when computing HSIC
            clusters_dict[child_id] = state["dataset"][variables_in_group].iloc[:1000].copy()
        except KeyError as e:
            # Chain the KeyError so the traceback shows it as the direct cause
            raise ValueError(f"Variable not found in dataset: {e.args[0]}") from e

    results = evaluate_clusters_hsic(clusters=clusters_dict)

    state["hsic_results"] = results

    return state

def hsic_test_outcome(state: DivideState):
    """Pick the node that follows the HSIC test.

    Returns the name of the data-driven refinement node when
    ``state["hsic_results"]`` is truthy, and ``END`` otherwise.
    Presumably a non-empty result means the HSIC evaluation flagged clusters
    that need refinement — verify against ``evaluate_clusters_hsic``.
    """
    nothing_flagged = not state["hsic_results"]
    if nothing_flagged:
        # NOTE(review): END stops the whole sub-graph here, without passing
        # through "Check Partitions Size" — confirm this is intended.
        return END

    # Results are present: hand over to the Divide Critic agent for refinement.
    return "Divide Critic - Data-driven Refinement"