from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage
from dataclasses import dataclass, field
from typing import Optional
from pandas import DataFrame
from causallearn.graph.GeneralGraph import GeneralGraph

# First we define custom dataclasses to store lists of variables and their associated description
# PartitionNode objects are stored in a PartitionTree, a graph-like data structure
@dataclass
class PartitionNode:
    # The list of variable names in this partition
    variable_names: list[str]
    # A string description associated with this partition
    description: str
    # Unique identifier for the partition
    id: str
    # List of edges in the causal graph between the variables in this partition
    # List of causal edges proposed by the agents, or a causallearn GeneralGraph object (if fci_driven_conquer is True)
    edge_list: Optional[list[list[str]] | GeneralGraph] = None
    # Whether the partition has been solved
    solved: bool = field(default=False)

@dataclass
class PartitionTree:
    # The root of the tree
    root: PartitionNode
    # Dictionary to store nodes by their ID
    nodes: dict[str, PartitionNode]
    # Dictionary to store edges by their ID. They should be though as outgoing edges from the node corresponding to the key ID
    # This allows to track how partitions are further subdivided into subpartitions, as parent-child relationships
    edges: dict[str, list[str]]

    def __init__(self, root: PartitionNode):
        """Initialize the tree with the root node."""
        self.root = root
        self.nodes = {root.id: root}  # Add the root node to the dictionary
        self.edges = {}  # Initialize edges as an empty dictionary

    def add_node(self, variables: list[str], parent_id: str, description: str, id: str):
        """
        Add a new node to the tree. If a node with the same ID already exists, it will be overwritten.
        If an edge between the parent_id and the new node's ID already exists, no duplicate edge will be created.

        Args:
            variables (list[str]): List of variable names for the new node.
            parent_id (str): ID of the parent node.
            description (str): Description of the new node.
            id (str): Unique identifier for the new node.

        Raises:
            ValueError: If the parent ID does not exist or if the variables are not a subset of the parent's partition.
        """
        if parent_id not in self.nodes:
            raise ValueError(f"Parent ID '{parent_id}' does not exist.")

        parent_node = self.nodes[parent_id]

        # Check if the new variables are a subset of the parent's partition
        if not set(variables).issubset(set(parent_node.variable_names)):
            raise ValueError(f"Variable(s) {' and '.join(list(set(variables) - set(parent_node.variable_names)))} is/are not part of the parent node's partition.")

        # Create the new partition
        new_partition = PartitionNode(variable_names=variables, description=description, id=id)

        # Overwrite the node if it already exists
        self.nodes[id] = new_partition

        # Add an entry into the edges dict, avoiding duplicate edges
        if parent_id in self.edges:
            if id not in self.edges[parent_id]:
                self.edges[parent_id].append(id)
        else:
            self.edges[parent_id] = [id]

    def get_leaf_ids_by_solved_status(self, solved: bool) -> list[str]:
        """Retrieve all leaf node IDs based on their 'solved' attribute."""
        leaf_ids = []

        def traverse(node: PartitionNode):
            # Check if the node's id is in the edge dict. If not, or has an empty item, it has no children.
            if node.id not in self.edges or len(self.edges[node.id]) == 0:
                if node.solved == solved:  # Check if 'solved' matches the desired status
                    leaf_ids.append(node.id)
            else:
                # Recursively traverse children
                for child_id in self.edges[node.id]:
                    traverse(self.nodes[child_id])

        traverse(self.root)  # Start traversal from the root
        return leaf_ids

    def get_unsolved_leaf_ids(self) -> list[str]:
        """Retrieve all leaf node IDs with 'solved' attribute equal to False."""
        return self.get_leaf_ids_by_solved_status(solved=False)

    def get_solved_leaf_ids(self) -> list[str]:
        """Retrieve all leaf node IDs with 'solved' attribute equal to True."""
        return self.get_leaf_ids_by_solved_status(solved=True)

    def set_solved(self, node_id: str, edge_list: list[list[str]] | GeneralGraph):
        """
        Update the 'edge_list' attribute of a node, either after conquer() or merge() has been called.
        Set the 'solved' attribute to True.
        If it has any children, remove them from the nodes and edges dictionaries.
        """
        if node_id not in self.nodes:
            raise ValueError(f"Node ID '{node_id}' does not exist.")
        
        self.nodes[node_id].edge_list = edge_list
        self.nodes[node_id].solved = True

        for child_id in self.edges.get(node_id, []):
            # Ensure the child_id exists in the nodes dictionary before attempting to delete
            if child_id in self.nodes:
                # Remove the child from the nodes dictionary
                del self.nodes[child_id]

            # Remove the child id from the list in the 'edges' dictionary
            self.edges[node_id] = [child_id2 for child_id2 in self.edges[node_id] if child_id2 != child_id]

    def get_ready_parents(self) -> list[str]:
        """Retrieve all unique parent IDs whose children all have solved attribute equal to True"""
        # Get the list of solved leaves
        solved_leaves_list = self.get_solved_leaf_ids()

        # Get the set of IDs of parents with solved leaves
        parents_set = set()
        for leaf_id in solved_leaves_list:
            for node_id in self.edges.keys():
                if any(child_id == leaf_id for child_id in self.edges[node_id]):
                    parents_set.add(node_id)

        # Get all "ready" parents, i.e. whose children are all solved
        ready_parents = []
        for parent_id in parents_set:
            children_ids = self.edges.get(parent_id, [])
            children = [self.nodes[child_id] for child_id in children_ids]
            if all(child.solved for child in children):
                ready_parents.append(parent_id)

        return ready_parents

    def get_partition_ids_with_more_than_k_nodes(self, k: int) -> list[str]:
        """Retrieve all leaf partitions ids with more than k variables."""
        partitions = []

        def traverse(node: PartitionNode):
            # Check if the node is a leaf (no outgoing edges)
            if node.id not in self.edges or len(self.edges[node.id]) == 0:
                # Check if the number of variables in the node is greater than k
                if len(node.variable_names) > k:
                    partitions.append(node.id)
            else:
                # Recursively traverse children
                for child_id in self.edges[node.id]:
                    traverse(self.nodes[child_id])

        traverse(self.root)  # Start traversal from the root
        return partitions


class CausalDiscoveryState(TypedDict):
    '''
    Define the state schema for the causal discovery process.
    '''
    messages: list[str] | list[BaseMessage]
    partition_tree: PartitionTree
    domain: str
    general_description: str
    causal_graph: list[tuple[str]] = field(default_factory=list)
    input_token_count: int
    output_token_count: int
    elapsed_time: float = 0.0 # This is the time taken to run the causal discovery process
    tool_calls: dict[str, int] = field(default_factory=dict) # This is a dict to store the number of times each tool has been called
    dataset: DataFrame
    ground_truth_graph: Optional[list[list[str]]] = None # Ground truth graph for the dataset, if available, used for computing intermediate metrics after each conquer step
    intermediate_metrics: Optional[list[dict]] = None # list of dicts to store intermediate metrics after each conquer step, if the option in config is chosen

class DivideState(TypedDict):
    '''
    Define the state schema for the Divide agent subgraph.
    Variable names and dataset description store info about the dataset and are integrated into prompts.
    The causal graph gets updated by the agents, and is stored as a list of edges.
    '''
    messages: list[str] | list[BaseMessage]
    partition_tree: PartitionTree
    domain: str
    general_description: str
    partition_ids_to_divide: list[str] # List of PartitionNode IDs that have more than k variables and need to be subdivided
    processed_partition_ids: list[str] # List of PartitionNode IDs that have been processed through the Divide agents (they may or may not have been subdivided, but we do not want agents to go over them again)
    current_partition_id: str # ID of the current partition being processed by the Divide agents
    hsic_results: dict[str, list[float | list[str] | list[str]]] # Dictionary to store silhouette scores, full list of variables and outliers for each partition ID, it is the result of evaluate_clusters_hsic
    input_token_count: int
    output_token_count: int
    tool_calls: dict[str, int] = field(default_factory=dict) # This is a dict to store the number of times each tool has been called
    dataset: DataFrame

class ConquerState(TypedDict):
    '''
    Define the state schema for the Hypothesis and Critic agents subgraph.
    Variable names and dataset description store info about the dataset and are integrated into prompts.
    The causal graph gets updated by the agents, and is stored as a list of edges.
    '''
    messages: list[str] | list[BaseMessage]
    variable_names: list[str]
    variable_description: str
    domain: str
    general_description: str
    causal_graph: list[tuple[str]] | GeneralGraph # List of causal edges proposed by the agents, or a causallearn GeneralGraph object (if fci_driven_conquer is True)
    input_token_count: int
    output_token_count: int
    tool_calls: dict[str, int] = field(default_factory=dict) # This is a dict to store the number of times each tool has been called
    dataset: DataFrame
    unsupported_edges: list[tuple[str]] = field(default_factory=list) # This is a list to store the unsupported edges found by the FCI algorithm
    supported_edges: list[tuple[str]] = field(default_factory=list) # This is a list to store the supported edges found by the FCI algorithm

class MergeState(TypedDict):
    '''
    Define the state schema for the Merge Hypothesis and Critic agents subgraph.
    Variable names and dataset description store info about the dataset and are integrated into prompts.
    The causal graph gets updated by the agents, and is stored as a list of edges.
    '''
    messages: list[str] | list[BaseMessage]
    groups: list[PartitionNode]
    group_connections: list[tuple[str]]
    domain: str
    general_description: str
    input_token_count: int
    output_token_count: int
    tool_calls: dict[str, int] = field(default_factory=dict) # This is a dict to store the number of times each tool has been called