from concurrent.futures import ProcessPoolExecutor, as_completed, wait

from langgraph.graph import END, START, StateGraph

from agents import CriticAgent, HypothesisAgent
from discovery_tools import ci_test
from state import CausalDiscoveryState, ConquerState
from utils.tool_utils import init_tool_list
from utils.causalgraph_utils import subset_graph
from utils.metrics import intermediate_eval_metrics

def conquer(state: CausalDiscoveryState, leaf_id: str, config: dict) -> CausalDiscoveryState:
    '''
    Solve a single PartitionTree leaf with a Hypothesis -> Critic LangGraph subgraph.
    Arguments:
    - state: the overall causal discovery state; it is updated in place and returned
    - leaf_id: key of the leaf in state["partition_tree"].nodes
    - config: run configuration; this function reads "tool_list", "data_driven_conquer" and "intermediate_metrics"
    When config["data_driven_conquer"] is truthy, the Critic step is followed by the ci_test node and, if that
    node leaves a non-empty "unsupported_edges" entry in the subgraph state, by CriticAgent.refine.
    On return, the leaf is marked as solved with the subgraph's causal_graph, the token and tool-call counters
    are incremented by the subgraph's counts and, if requested, an intermediate-metrics entry is appended.
    '''
    leaf = state["partition_tree"].nodes[leaf_id]
    variable_names = leaf.variable_names

    # Nothing to discover inside a one-variable partition: mark it solved with no edges
    if len(variable_names) == 1:
        state["partition_tree"].set_solved(leaf_id, edge_list=[])
        return state

    # Both agents share the same tool instances
    tools = init_tool_list(config["tool_list"])
    hypothesis = HypothesisAgent(tool_list=tools)
    critic = CriticAgent(tool_list=tools)

    # Topology: START -> Hypothesis -> Critic -> END, or, when data-driven,
    # START -> Hypothesis -> Critic -> CI_test -> (refinement | END)
    builder = StateGraph(state_schema=ConquerState)
    builder.add_node("Hypothesis", hypothesis.go)
    builder.add_node("Critic", critic.go)
    builder.add_edge(START, "Hypothesis")
    builder.add_edge("Hypothesis", "Critic")

    if config["data_driven_conquer"]:
        refinement_node = "Critic - Data-driven Refinement"

        def ci_test_outcome(sub_state: ConquerState):
            # Refine only when the CI test flagged at least one unsupported edge; a missing or
            # empty "unsupported_edges" entry (refinement ran but found nothing) ends the subgraph
            flagged = sub_state["unsupported_edges"] if "unsupported_edges" in sub_state else ()
            return refinement_node if len(flagged) > 0 else END

        builder.add_node("CI_test", ci_test)
        builder.add_node(refinement_node, critic.refine)
        builder.add_edge("Critic", "CI_test")
        builder.add_conditional_edges("CI_test", ci_test_outcome)
    else:
        builder.add_edge("Critic", END)

    subgraph = builder.compile()

    # Only the dataset columns belonging to this partition are handed to the subgraph
    leaf_data = state["dataset"][variable_names]

    leaf_input = ConquerState(
        messages=state["messages"],
        variable_names=variable_names,
        variable_description=leaf.description,
        general_description=state["general_description"],
        domain=state["domain"],
        causal_graph=[],
        input_token_count=0,
        output_token_count=0,
        tool_calls={},
        dataset=leaf_data,
    )

    result = subgraph.invoke(input=leaf_input)
    # NOTE(review): the subgraph's output messages are not copied back into `state`; confirm this is intended.
    leaf_edges = result["causal_graph"]

    state["partition_tree"].set_solved(leaf_id, edge_list=leaf_edges)

    # The subgraph counters start at zero, so its totals are exactly this leaf's increase
    state["input_token_count"] += result["input_token_count"]
    state["output_token_count"] += result["output_token_count"]

    tool_counter = state["tool_calls"]
    for tool_name, n_calls in result["tool_calls"].items():
        tool_counter[tool_name] = tool_counter.get(tool_name, 0) + n_calls

    if config["intermediate_metrics"]:
        from networkx.classes.digraph import DiGraph

        true_leaf_graph = DiGraph(subset_graph(graph=state["ground_truth_graph"], nodes=variable_names))
        metrics = intermediate_eval_metrics(true_leaf_graph, DiGraph(leaf_edges))
        metrics["node_id"] = leaf_id
        metrics["nodes"] = variable_names
        state["intermediate_metrics"].append(metrics)

    return state

def fci_driven_conquer(state: CausalDiscoveryState, leaf_id: str, config: dict) -> CausalDiscoveryState:
    '''
    Solve a single PartitionTree leaf with Hypothesis -> Critic agents followed by an FCI run.
    The FCI step (new_fci.fci_algorithm) runs on the leaf's dataset columns, using the edges proposed by
    the agents as edge constraints, and its output replaces the subgraph's causal_graph.
    Arguments:
    - state: the overall causal discovery state; it is updated in place and returned
    - leaf_id: key of the leaf in state["partition_tree"].nodes
    - config: run configuration; this function reads "tool_list" and "intermediate_metrics"
    On return, the leaf is marked as solved with the resulting causal_graph, the token and tool-call counters
    are incremented by the subgraph's counts and, if requested, an intermediate-metrics entry is appended.
    '''
    leaf = state["partition_tree"].nodes[leaf_id]
    variable_names = leaf.variable_names

    # If partition has just one variable, there is nothing to discover: mark it solved with no edges
    if len(variable_names) == 1:
        state["partition_tree"].set_solved(leaf_id, edge_list=[])
        return state

    # Instantiate the HypothesisAgent and CriticAgent with the provided tools
    tool_list = init_tool_list(config["tool_list"])
    hypothesis_agent_instance = HypothesisAgent(tool_list=tool_list)
    critic_agent_instance = CriticAgent(tool_list=tool_list)

    def fci(sub_state: ConquerState) -> dict:
        from langchain_core.messages import AIMessage
        from new_fci import fci_algorithm

        # We run the FCI algorithm using the edges proposed by the agents as prior knowledge
        fci_edges = fci_algorithm(data=sub_state["dataset"], edge_constraints=sub_state["causal_graph"])

        # Return only the keys this node changes. Returning the whole (mutated) input state would
        # re-emit every channel and double-apply any ConquerState key that uses a reducer.
        return {
            "causal_graph": fci_edges,
            "messages": [AIMessage(
                content=f"FCI algorithm executed successfully on the leaf node based on the proposed edges. Causal graph generated:{fci_edges}",
                name="FCI"
            )],
        }

    # Build the subgraph: START -> Hypothesis -> Critic -> FCI -> END
    hyp_crit_subgraph_builder = StateGraph(state_schema=ConquerState)
    hyp_crit_subgraph_builder.add_node("Hypothesis", hypothesis_agent_instance.go)
    hyp_crit_subgraph_builder.add_edge(START, "Hypothesis")
    hyp_crit_subgraph_builder.add_node("Critic", critic_agent_instance.go)
    hyp_crit_subgraph_builder.add_edge("Hypothesis", "Critic")
    hyp_crit_subgraph_builder.add_node("FCI", fci)
    hyp_crit_subgraph_builder.add_edge("Critic", "FCI")
    hyp_crit_subgraph_builder.add_edge("FCI", END)

    subgraph = hyp_crit_subgraph_builder.compile()

    # Subset the dataset to only include the variables of this leaf
    subset_dataset = state["dataset"][variable_names]

    input_state = ConquerState(messages=state["messages"],
                               variable_names=variable_names,
                               variable_description=leaf.description,
                               general_description=state["general_description"],
                               domain=state["domain"],
                               causal_graph=[],
                               input_token_count=0,
                               output_token_count=0,
                               tool_calls={},
                               dataset=subset_dataset)

    output = subgraph.invoke(input=input_state)

    # Update the partition tree with the new causal graph for the partition, and mark it as solved
    state["partition_tree"].set_solved(leaf_id, edge_list=output["causal_graph"])

    # The subgraph counters start at zero, so its totals are exactly this leaf's increase
    state["input_token_count"] += output["input_token_count"]
    state["output_token_count"] += output["output_token_count"]

    # Update the call counter for the tools used in this task
    for tool, count in output["tool_calls"].items():
        state["tool_calls"][tool] = state["tool_calls"].get(tool, 0) + count

    if config["intermediate_metrics"]:
        from networkx.classes.digraph import DiGraph
        intermediate_metrics = intermediate_eval_metrics(
            DiGraph(subset_graph(graph=state["ground_truth_graph"], nodes=variable_names)),
            DiGraph(output["causal_graph"])
        )
        intermediate_metrics["node_id"] = leaf_id
        intermediate_metrics["nodes"] = variable_names

        state["intermediate_metrics"].append(intermediate_metrics)

    return state


def merge_updated_states(causal_discovery_state: "CausalDiscoveryState", updated_states: dict) -> "CausalDiscoveryState":
    '''
    Merge per-leaf results (e.g. from conquer()/fci_driven_conquer() run in worker processes) into the shared state.
    Each value of updated_states is a separate copy of the state that started from causal_discovery_state and
    solved the leaf named by its key. conquer() adds its own increase onto the counters of the copy it received,
    so for every counter only (copy value - starting value) is added back here.
    Arguments:
    - causal_discovery_state: the shared state; it is updated in place and returned
    - updated_states: dict mapping leaf id -> state copy returned for that leaf (may be empty)
    We update:
    - The solved node of each leaf
    - The "messages" list (last 2 messages of each copy)
    - The token counters and the tool_calls counters (summing the increase of every copy)
    - The "intermediate_metrics" list, if present (entries appended by each copy)
    '''
    starting_input_token_count = causal_discovery_state["input_token_count"]
    starting_output_token_count = causal_discovery_state["output_token_count"]
    # Copy, so the starting tool counts are not affected by the in-place update at the end
    starting_tool_calls = dict(causal_discovery_state["tool_calls"])

    merged_metrics = causal_discovery_state.get("intermediate_metrics")
    starting_metrics_len = len(merged_metrics) if merged_metrics is not None else 0

    # Cumulative increases across all copies, applied once after the loop
    input_token_increase = 0
    output_token_increase = 0
    tool_call_increase = {}

    for key, item in updated_states.items():
        # Match the results in updated_states to the overall state based on the key
        causal_discovery_state["partition_tree"].nodes[key] = item["partition_tree"].nodes[key]

        # Append the last 2 messages of the updated state (intended: HypothesisAgent + CriticAgent).
        # NOTE(review): conquer() does not write the subgraph's messages back into its state copy;
        # confirm these are actually the new messages rather than pre-existing ones.
        causal_discovery_state["messages"].extend(item["messages"][-2:])

        input_token_increase += item["input_token_count"] - starting_input_token_count
        output_token_increase += item["output_token_count"] - starting_output_token_count

        # Accumulate the per-tool increase of this copy (not just the last copy's counts)
        for tool, count in item["tool_calls"].items():
            tool_call_increase[tool] = (
                tool_call_increase.get(tool, 0) + count - starting_tool_calls.get(tool, 0)
            )

        # Metrics computed inside a worker only exist in its copy: carry over the new entries.
        # Skip if the copy shares the very same list (nothing new to carry, avoids duplicates).
        if merged_metrics is not None:
            item_metrics = item.get("intermediate_metrics")
            if item_metrics is not None and item_metrics is not merged_metrics:
                merged_metrics.extend(item_metrics[starting_metrics_len:])

    causal_discovery_state["input_token_count"] += input_token_increase
    causal_discovery_state["output_token_count"] += output_token_increase

    tool_counter = causal_discovery_state["tool_calls"]
    for tool, increase in tool_call_increase.items():
        tool_counter[tool] = tool_counter.get(tool, 0) + increase

    return causal_discovery_state


def parallel_conquer(causal_discovery_state: CausalDiscoveryState, unsolved_leaves_list: list[str], config: dict, max_workers: int = 5) -> CausalDiscoveryState:
    '''
    Function to make parallel calls and merge results of the conquer() function on PartitionTree leaves (which are PartitionNode objects)
    Arguments:
    - causal_discovery_state: the current state of the causal discovery process
    - unsolved_leaves_list: a list of string IDs corresponding to PartitionNode objects, to be solved by the conquer() function
    - config: run configuration, forwarded to each worker; if config["fci_conquer"] is truthy, fci_driven_conquer() is used instead of conquer()
    - max_workers: number of worker processes (default 5)
    Each leaf is solved in a separate process on its own copy of the state; the results, stored in the updated_states dict,
    are matched to the overall causal_discovery_state based on the dict key (see merge_updated_states()).
    A leaf whose worker raises is reported on stdout and left unsolved.
    We update:
    - The node which has been solved
    - The "messages" list
    - The token counters
    '''
    # Nothing to solve: avoid starting worker processes at all
    if not unsolved_leaves_list:
        return causal_discovery_state

    solver = fci_driven_conquer if config and config.get("fci_conquer") else conquer

    # Process pool (not threads): each task gets a pickled copy of the state.
    # The with-block waits for all tasks and shuts the pool down, so worker processes are not leaked.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(solver, causal_discovery_state, leaf_id, config): leaf_id  # Map futures to leaf ID
            for leaf_id in unsolved_leaves_list
        }

    # Collect results in submission order (not completion order) so the merge is deterministic.
    # The key is the id of the leaf that the solver has acted on.
    updated_states = {}
    for future, leaf_id in futures.items():
        try:
            updated_states[leaf_id] = future.result()
        except Exception as e:
            print(f"Error in parallel execution for leaf {leaf_id}: {e}")

    # Every worker failed: nothing to merge, the state is unchanged
    if not updated_states:
        return causal_discovery_state

    return merge_updated_states(causal_discovery_state, updated_states)