import os
import json
import time
from typing import Optional, Literal
from langchain_core.runnables import RunnableConfig
import pandas as pd
from utils.tool_utils import init_tool_list
from state import CausalDiscoveryState, PartitionTree, PartitionNode
from agents import ExplainerAgent, DivideAgent
from conquer import conquer, fci_driven_conquer, parallel_conquer
from merge import merge
from divide import divide_loop
from langgraph.graph import StateGraph, START, END


def causal_discovery_llm(
    mode: Literal["script", "UI"] = "script",
    dataset_file: Optional[str] = None,
    dataset_content: Optional[pd.DataFrame] = None,
    description_content: Optional[dict] = None,
    ground_truth_content: Optional[dict] = None,
    parallelize: bool = False,
    config: Optional[dict] = None,
    state: Optional[dict | CausalDiscoveryState] = None,
) -> CausalDiscoveryState:
    """
    A function that invokes the main causal discovery graph for the given dataset.

    Args:
        mode (Literal["script", "UI"]): The mode of operation, either "script" or "UI".
        dataset_file (Optional[str]): The path to the dataset file (used in "script" mode).
        dataset_content (Optional[pd.DataFrame]): The dataset as a pandas DataFrame (used in "UI" mode).
        description_content (Optional[dict]): The dataset description as a dictionary (used in "UI" mode).
        ground_truth_content (Optional[dict]): The ground truth graph as a dictionary (used in "UI" mode).
        parallelize (bool): Whether to parallelize the conquer step.
        config (Optional[dict]): Configuration dictionary for the causal discovery process.
        state (Optional[dict | CausalDiscoveryState]): An optional state object to use when called from another langgraph StateGraph (with some additional adjustments)
    """

    if mode == "script":
        if not dataset_file:
            raise ValueError("dataset_file must be provided in 'script' mode.")
        
        # Load dataset from file
        data = pd.read_csv(dataset_file)

        # Load description JSON file
        description_file = os.path.join(
            os.path.dirname(dataset_file), "description.json"
        )
        if os.path.exists(description_file):
            with open(description_file, "r") as file:
                description_data = json.load(file)
                dataset_description = description_data.get("description", "")
                domain = description_data.get("domain", "")
        else:
            dataset_description = ""
            domain = ""

    elif mode == "UI":
        if dataset_content is None or description_content is None or ground_truth_content is None:
            raise ValueError(
                "dataset_content, description_content, and ground_truth_content must be provided in 'UI' mode."
            )

        # Use provided content
        data = dataset_content
        dataset_description = description_content.get("description", "")
        domain = description_content.get("domain", "")

    else:
        raise ValueError("Invalid mode. Choose either 'script' or 'UI'.")

    # Get the variable names from the dataset
    variable_names = list(data.columns.values)

    # Create the root node of the partition tree for the dataset variables
    root = PartitionNode(
        variable_names=variable_names, description=dataset_description, id="0"
    )

    causal_discovery_state = CausalDiscoveryState(
        messages=state["messages"] if state else [],
        partition_tree=PartitionTree(root=root),
        domain=domain,
        general_description="",
        causal_graph=[],
        input_token_count=0,
        output_token_count=0,
        tool_calls={},
        dataset=data,
        ground_truth_graph=ground_truth_content["edges"] if config["intermediate_metrics"] else None,
        intermediate_metrics=[] if config["intermediate_metrics"] else None,
    )

    # List of StructuredTool objects
    tool_list = init_tool_list(config["tool_list"])

    explainer_agent_instance = ExplainerAgent(tool_list=tool_list)
    # divide_agent_instance = DivideAgent(tool_list=tool_list)

    divide = lambda s: divide_loop(
        state=s, config=config
    )

    # This executes the conquer and merge logic, it has to be defined as a lambda due to the extra argument "parallelize"
    conquer_merge = lambda s: conquer_merge_loop(
        causal_discovery_state=s, parallelize=parallelize, config=config
    )

    causal_discovery_graph_builder = StateGraph(CausalDiscoveryState)
    causal_discovery_graph_builder.add_node("Explainer", explainer_agent_instance.go)
    causal_discovery_graph_builder.add_node("Divide", divide)
    causal_discovery_graph_builder.add_node("Conquer & Merge", conquer_merge)
    causal_discovery_graph_builder.add_edge(START, "Explainer")
    causal_discovery_graph_builder.add_edge("Explainer", "Divide")
    causal_discovery_graph_builder.add_edge("Divide", "Conquer & Merge")
    causal_discovery_graph_builder.add_edge("Conquer & Merge", END)
    causal_discovery_graph = causal_discovery_graph_builder.compile()

    start_time = time.time()
    causal_discovery_state = causal_discovery_graph.invoke(input=causal_discovery_state, config = RunnableConfig(recursion_limit=100))
    end_time = time.time()

    elapsed_time = end_time - start_time

    causal_discovery_state["elapsed_time"] = elapsed_time

    return causal_discovery_state


def _conquer_leaves(
    causal_discovery_state: CausalDiscoveryState,
    leaf_ids: list,
    parallelize: bool,
    config: Optional[dict],
) -> CausalDiscoveryState:
    """Run one conquer pass over ``leaf_ids`` (parallel, FCI-driven or plain) and return the new state."""
    if parallelize:
        return parallel_conquer(
            causal_discovery_state=causal_discovery_state,
            unsolved_leaves_list=leaf_ids,
            config=config,
        )

    use_fci = bool(config and config.get("fci_conquer"))
    for leaf_id in leaf_ids:
        if use_fci:
            causal_discovery_state = fci_driven_conquer(
                state=causal_discovery_state, leaf_id=leaf_id, config=config
            )
        else:
            # conquer is expected to mark the leaf as solved.
            causal_discovery_state = conquer(
                state=causal_discovery_state, leaf_id=leaf_id, config=config
            )
    return causal_discovery_state


def _merge_ready_parents(
    causal_discovery_state: CausalDiscoveryState,
    config: Optional[dict],
) -> tuple[CausalDiscoveryState, bool]:
    """Merge every "ready" parent until none remain; return the new state and whether any merge ran."""
    merged_any = False
    ready_parents = causal_discovery_state["partition_tree"].get_ready_parents()
    while ready_parents:
        for parent_id in ready_parents:
            # merge prunes the solved children and marks the parent as solved.
            causal_discovery_state = merge(
                state=causal_discovery_state, parent_node_id=parent_id, config=config
            )
        merged_any = True
        # Merging can make grandparents ready, so re-query the tree.
        ready_parents = causal_discovery_state["partition_tree"].get_ready_parents()
    return causal_discovery_state, merged_any


def conquer_merge_loop(
    causal_discovery_state: CausalDiscoveryState,
    parallelize=True,
    config: Optional[dict] = None,
) -> CausalDiscoveryState:
    """
    Conquer the unsolved leaves of the PartitionTree and merge solved subtrees up to the root.

    Each round runs conquer on all unsolved leaves (re-running them if they remain unsolved),
    then repeatedly merges every parent whose children are all solved. Rounds continue until
    the root is solved, whose edge list then becomes ``causal_discovery_state["causal_graph"]``.

    Args:
        causal_discovery_state (CausalDiscoveryState): The current state of the causal discovery process.
        parallelize (bool): Whether to run the conquer step in parallel. Overridden by
            ``config["parallelize_conquer"]`` when that key is present and not None.
        config (Optional[dict]): Configuration for the causal discovery process. Reads
            "parallelize_conquer" and "fci_conquer" and is forwarded to conquer/merge.

    Returns:
        CausalDiscoveryState: The updated state, with "causal_graph" set to the root's edge list.

    Raises:
        ValueError: If the root is unsolved but there are neither unsolved leaves nor ready
            parents, so no further progress is possible.
    """
    # Override parallelize if specified in the config. Falsy values such as False must
    # also apply, so that the config can disable parallelism as well as enable it.
    if config and config.get("parallelize_conquer") is not None:
        parallelize = bool(config["parallelize_conquer"])

    while not causal_discovery_state["partition_tree"].root.solved:
        # 1) Conquer every unsolved leaf.
        unsolved_leaves_list = causal_discovery_state["partition_tree"].get_unsolved_leaf_ids()
        if unsolved_leaves_list:
            causal_discovery_state = _conquer_leaves(
                causal_discovery_state, unsolved_leaves_list, parallelize, config
            )

        # 2) Merge all parents whose children are solved, bubbling up toward the root.
        causal_discovery_state, merged_any = _merge_ready_parents(causal_discovery_state, config)

        # 3) Without leaves to conquer or parents to merge, nothing changes the tree;
        #    looping again would spin forever.
        if not unsolved_leaves_list and not merged_any:
            raise ValueError(
                "The root node of the partition tree is not solved, and there are no "
                "unsolved leaves or ready parents left to make progress with."
            )

    # The solved root holds the merged causal graph.
    causal_discovery_state["causal_graph"] = causal_discovery_state["partition_tree"].root.edge_list
    return causal_discovery_state