"""
This module implements the KnowledgeGraph class for representing mathematical knowledge.

The knowledge graph is a directed graph where:
- Nodes represent mathematical entities (concepts, conjectures, theorems)
- Edges represent relationships between entities; currently only construction relationships are implemented (logical ones such as implication are planned)
"""

from datetime import datetime
from enum import Enum, auto
from typing import List, Optional, Set, Dict, Any, Tuple, Union
import networkx as nx
from dataclasses import dataclass
from graphviz import Digraph
import os
import json
import pickle
import dill
import logging

from frame.knowledge_base.entities import (
    Entity, Concept, Conjecture, Theorem, Example, ExampleStructure
)
from frame.productions.base import ProductionRule
from frame.provers.base_prover import ProofResult
from frame.provers.proof import Proof

logger = logging.getLogger(__name__)

class NodeType(Enum):
    """Kind of entity held by a knowledge-graph node (stored as its ``node_type``)."""

    CONCEPT = "concept"
    CONJECTURE = "conjecture"
    THEOREM = "theorem"


class RelationType(Enum):
    """Kinds of directed edge that can link two entities in the knowledge graph.

    Note(_; 2/18): only construction provenance is modelled so far. Planned
    additions include:
    - logical relationships (e.g. IMPLIES, IMPLIED_BY, CONTRADICTS)
    - structural relationships (e.g. SPECIALIZATION_OF, GENERALIZATION_OF,
      ANALOGOUS_TO, DUAL_OF)
    """

    # --- Construction provenance ---
    # Source entity was an input used when building the target entity.
    USED_IN_CONSTRUCTION = auto()
    # Source entity was built using the target entity.
    CONSTRUCTED_FROM = auto()


@dataclass
class ConstructionStep:
    """Provenance record describing how one graph entity was produced."""

    # The production rule that was applied.
    rule: ProductionRule
    # IDs of existing graph nodes consumed by the rule; when a step is passed to
    # the graph's add_* methods, a construction edge is drawn from each of them.
    input_node_ids: List[str]
    # Rule-specific arguments; the schema is defined by the rule, not here.
    parameters: Dict[str, Any]
    # When the step was recorded (naive vs. timezone-aware is not enforced here).
    timestamp: datetime


class KnowledgeGraph(nx.DiGraph):
    """A directed graph representing mathematical knowledge.

    Inherits from NetworkX's DiGraph to leverage its graph algorithms
    while adding domain-specific functionality for mathematical concepts.
    """

    def __init__(self):
        """Create an empty knowledge graph.

        Sets up the per-type ID counters, the pools of ID numbers freed by node
        removal, the global step counter stamped onto new nodes, and the
        registry of observed instances.
        """
        super().__init__()

        # Next fresh number for each ID prefix ("concept_0", "concept_1", ...).
        self._concept_counter = 0
        self._conjecture_counter = 0
        self._theorem_counter = 0

        # Number of entities added so far; each new node stores it as "creation_step".
        self._step_counter = 0

        # Numbers released by node removal, keyed by prefix ("concept",
        # "conjecture", "theorem"); they are reused before the counters advance.
        self._available_ids = {node_type.value: set() for node_type in NodeType}

        # Registry of every concrete value seen in an example or nonexample of any
        # entity. It is keyed by the structure's component_types tuple only, so
        # entities of different kinds whose values share a shape share one entry.
        self._instances_by_structure = {}  # Maps Tuple[ExampleType, ...] to set of values

    # Node management methods
    def _generate_node_id(self, prefix: str) -> str:
        """Return a fresh node ID of the form ``"<prefix>_<number>"``.

        Numbers freed by earlier removals are recycled first, smallest first.
        Only when none are pending does the ``_<prefix>_counter`` attribute
        advance.

        Args:
            prefix: A key of ``self._available_ids`` ("concept", "conjecture"
                or "theorem").
        """
        freed = self._available_ids[prefix]
        if freed:
            number = min(freed)
            freed.remove(number)
        else:
            counter_name = f"_{prefix}_counter"
            number = getattr(self, counter_name)
            setattr(self, counter_name, number + 1)
        return f"{prefix}_{number}"
    
    def _find_duplicate_entity(self, entity: Entity) -> Optional[str]:
        """
        Look for a node that already stores an entity equal to ``entity``.

        Only nodes of the same NodeType are compared, using ``==`` on the stored
        entity. Entities that are not a Concept, Conjecture or Theorem never
        match. Returns the matching node ID, or None.
        """
        # The first matching class wins, so this order matters if the entity
        # classes subclass one another.
        kinds = (
            (Concept, NodeType.CONCEPT),
            (Conjecture, NodeType.CONJECTURE),
            (Theorem, NodeType.THEOREM),
        )
        wanted = next((kind for cls, kind in kinds if isinstance(entity, cls)), None)
        if wanted is None:
            return None

        # Linear scan over every node in the graph.
        return next(
            (
                node_id
                for node_id, attrs in self.nodes(data=True)
                if attrs['node_type'] == wanted and attrs['entity'] == entity
            ),
            None,
        )

    def _has_edge_type(self, source: str, target: str, relation: RelationType) -> bool:
        """Tell whether the ``source -> target`` edge is tagged with ``relation``.

        Each attribute value stored on the edge is inspected:
        - a bare RelationType matches if it equals ``relation``;
        - a dict matches if its ``'relation'`` entry equals ``relation``.

        Returns False when there is no such edge.
        """
        attributes = self.get_edge_data(source, target)
        if attributes is None:
            return False

        def _tagged(value: Any) -> bool:
            if isinstance(value, RelationType):
                return value == relation
            return isinstance(value, dict) and 'relation' in value and value['relation'] == relation

        return any(_tagged(value) for value in attributes.values())

    def add_concept(self, concept: Concept, construction_step: Optional[ConstructionStep] = None) -> str:
        """Add a concept to the knowledge graph if it doesn't already exist.

        If an equal concept is already stored, a warning is logged and the
        existing node ID is returned; ``construction_step`` is then ignored.

        Otherwise the method:
        - creates a new ``concept_<n>`` node stamped with the current step counter;
        - adds a USED_IN_CONSTRUCTION edge from each input node of
          ``construction_step``;
        - registers the concept's existing examples and nonexamples as tracked
          instances.

        Args:
            concept: The concept to add.
            construction_step: How the concept was constructed, if known.

        Returns:
            The ID of the new node, or of the already-existing duplicate.

        Raises:
            ValueError: If an input node of ``construction_step`` is not in the
                graph. This is checked before anything is modified, so a failed
                call leaves the graph, the ID pools and the step counter untouched.
        """
        # Check if concept already exists
        existing_id = self._find_duplicate_entity(concept)
        if existing_id is not None:
            logger.warning(f"Concept {concept.name} already exists with ID {existing_id}")
            return existing_id

        # Validate inputs before mutating anything, so a bad input ID cannot
        # leave a half-wired node behind.
        if construction_step:
            missing = [
                input_id for input_id in construction_step.input_node_ids
                if input_id not in self
            ]
            if missing:
                raise ValueError(f"Input nodes {missing} of construction step not found in graph")

        node_id = self._generate_node_id("concept")
        super().add_node(
            node_id,
            entity=concept,
            node_type=NodeType.CONCEPT,
            construction_step=construction_step,
            creation_step=self._step_counter,  # Store the current step count
        )
        self._step_counter += 1  # Increment the step counter

        if construction_step:
            for input_id in construction_step.input_node_ids:
                self.add_construction_edge(input_id, node_id)

        # Register the concept's existing examples/nonexamples as observed instances
        if hasattr(concept, "examples"):
            for example in concept.examples.get_examples():
                self._track_instance(example)
            for example in concept.examples.get_nonexamples():
                self._track_instance(example)

        return node_id
    
    def add_conjecture(self, conjecture: Conjecture, construction_step: Optional[ConstructionStep] = None) -> str:
        """Add a conjecture to the knowledge graph if it doesn't already exist.

        If an equal conjecture is already stored, a warning is logged and the
        existing node ID is returned; ``construction_step`` is then ignored.
        Otherwise a new ``conjecture_<n>`` node is created, stamped with the
        current step counter, and wired with a USED_IN_CONSTRUCTION edge from
        each input node of ``construction_step``.

        Args:
            conjecture: The conjecture to add.
            construction_step: How the conjecture was constructed, if known.

        Returns:
            The ID of the new node, or of the already-existing duplicate.

        Raises:
            ValueError: If an input node of ``construction_step`` is not in the
                graph. This is checked before anything is modified, so a failed
                call leaves the graph, the ID pools and the step counter untouched.
        """
        # Check if conjecture already exists
        existing_id = self._find_duplicate_entity(conjecture)
        if existing_id is not None:
            logger.warning(f"Conjecture already exists with ID {existing_id}")
            return existing_id

        # Validate inputs before mutating anything, so a bad input ID cannot
        # leave a half-wired node behind.
        if construction_step:
            missing = [
                input_id for input_id in construction_step.input_node_ids
                if input_id not in self
            ]
            if missing:
                raise ValueError(f"Input nodes {missing} of construction step not found in graph")

        node_id = self._generate_node_id("conjecture")
        super().add_node(
            node_id,
            entity=conjecture,
            node_type=NodeType.CONJECTURE,
            construction_step=construction_step,
            creation_step=self._step_counter,  # Store the current step count
        )
        self._step_counter += 1  # Increment the step counter

        if construction_step:
            for input_id in construction_step.input_node_ids:
                self.add_construction_edge(input_id, node_id)

        return node_id
    
    def add_theorem(self, theorem: Theorem, construction_step: Optional[ConstructionStep] = None) -> str:
        """Add a theorem to the knowledge graph if it doesn't already exist.

        If an equal theorem is already stored, a warning is logged and the
        existing node ID is returned; ``construction_step`` is then ignored.
        Otherwise a new ``theorem_<n>`` node is created, stamped with the
        current step counter, and wired with a USED_IN_CONSTRUCTION edge from
        each input node of ``construction_step``.

        Args:
            theorem: The theorem to add.
            construction_step: How the theorem was constructed, if known.

        Returns:
            The ID of the new node, or of the already-existing duplicate.

        Raises:
            ValueError: If an input node of ``construction_step`` is not in the
                graph. This is checked before anything is modified, so a failed
                call leaves the graph, the ID pools and the step counter untouched.
        """
        # Check if theorem already exists
        existing_id = self._find_duplicate_entity(theorem)
        if existing_id is not None:
            logger.warning(f"Theorem already exists with ID {existing_id}")
            return existing_id

        # Validate inputs before mutating anything, so a bad input ID cannot
        # leave a half-wired node behind.
        if construction_step:
            missing = [
                input_id for input_id in construction_step.input_node_ids
                if input_id not in self
            ]
            if missing:
                raise ValueError(f"Input nodes {missing} of construction step not found in graph")

        node_id = self._generate_node_id("theorem")
        super().add_node(
            node_id,
            entity=theorem,
            node_type=NodeType.THEOREM,
            construction_step=construction_step,
            creation_step=self._step_counter,  # Store the current step count
        )
        self._step_counter += 1  # Increment the step counter

        if construction_step:
            for input_id in construction_step.input_node_ids:
                self.add_construction_edge(input_id, node_id)

        return node_id

    def get_step_counter(self) -> int:
        """Number of entities added so far, which is also the next node's ``creation_step``."""
        current = self._step_counter
        return current

    def get_entity_creation_step(self, entity_id: str) -> int:
        """Return the ``creation_step`` stored on ``entity_id`` (0 if the node has none).

        Raises:
            ValueError: If ``entity_id`` is not a node of the graph.
        """
        if entity_id in self:
            return self.nodes[entity_id].get("creation_step", 0)
        raise ValueError(f"Entity {entity_id} not found in graph")
        
    def get_entity_step_age(self, entity_id: str) -> int:
        """Return the node's age in construction steps.

        Computed as the current step counter minus the node's ``creation_step``
        (a missing ``creation_step`` counts as 0). The most recently added node
        therefore has age 1.

        Raises:
            ValueError: If ``entity_id`` is not a node of the graph.
        """
        if entity_id not in self:
            raise ValueError(f"Entity {entity_id} not found in graph")
        born_at = self.nodes[entity_id].get("creation_step", 0)
        return self._step_counter - born_at

    def get_node(self, node_id: str) -> Tuple[Entity, NodeType, Optional[ConstructionStep]]:
        """Return ``(entity, node_type, construction_step)`` for ``node_id``.

        ``construction_step`` is None when the node has none recorded.

        Raises:
            ValueError: If ``node_id`` is not a node of the graph.
        """
        if node_id not in self:
            raise ValueError(f"Node {node_id} not found in graph")

        attrs = self.nodes[node_id]
        entity = attrs["entity"]
        node_type = attrs["node_type"]
        step = attrs.get("construction_step")
        return entity, node_type, step

    # Edge and relationship methods
    def add_relationship(self, source: str, target: str, relation: RelationType):
        """Connect ``source -> target`` with a ``relation`` edge unless it is already there.

        The underlying DiGraph stores at most one edge per ordered pair. Adding
        a different relation between the same two nodes therefore replaces the
        stored ``relation`` attribute rather than adding a second edge.

        Raises:
            ValueError: If either endpoint is missing from the graph.
        """
        endpoints_present = source in self and target in self
        if not endpoints_present:
            raise ValueError("Both source and target must exist in the graph")

        if self._has_edge_type(source, target, relation):
            return
        super().add_edge(source, target, relation=relation)
    
    def add_construction_edge(self, source: str, target: str):
        """Record that ``source`` was an input used to construct ``target``.

        Adds a USED_IN_CONSTRUCTION edge from the input (source) to the
        constructed entity (target); nothing happens if that edge already exists.

        Raises:
            ValueError: If either node is missing from the graph.
        """
        relation = RelationType.USED_IN_CONSTRUCTION
        self.add_relationship(source, target, relation)

    # Query methods
    def get_related_entities(
        self, entity_id: str, relation_type: Optional[RelationType] = None
    ) -> List[str]:
        """List the successors of ``entity_id``, optionally filtered by relation.

        Only outgoing edges are followed (``neighbors`` on a DiGraph yields
        successors). When ``relation_type`` is given, only successors whose
        edge carries that ``relation`` attribute are kept.

        Raises:
            ValueError: If ``entity_id`` is not in the graph.
            KeyError: If filtering and an outgoing edge lacks a ``relation`` attribute.
        """
        if entity_id not in self:
            raise ValueError(f"Entity {entity_id} not found")

        successors = list(self.neighbors(entity_id))
        if relation_type is None:
            return successors

        return [
            succ
            for succ in successors
            if self.get_edge_data(entity_id, succ)["relation"] == relation_type
        ]

    def get_related_concepts(
        self, entity_id: str, relation_type: Optional[RelationType] = None
    ) -> List[str]:
        """Like ``get_related_entities`` but keep only concept nodes."""
        concept_ids = []
        for candidate in self.get_related_entities(entity_id, relation_type):
            if self.nodes[candidate]["node_type"] == NodeType.CONCEPT:
                concept_ids.append(candidate)
        return concept_ids

    def get_related_conjectures(
        self, entity_id: str, relation_type: Optional[RelationType] = None
    ) -> List[str]:
        """Like ``get_related_entities`` but keep only conjecture nodes."""
        conjecture_ids = []
        for candidate in self.get_related_entities(entity_id, relation_type):
            if self.nodes[candidate]["node_type"] == NodeType.CONJECTURE:
                conjecture_ids.append(candidate)
        return conjecture_ids

    def get_related_theorems(
        self, entity_id: str, relation_type: Optional[RelationType] = None
    ) -> List[str]:
        """Like ``get_related_entities`` but keep only theorem nodes."""
        theorem_ids = []
        for candidate in self.get_related_entities(entity_id, relation_type):
            if self.nodes[candidate]["node_type"] == NodeType.THEOREM:
                theorem_ids.append(candidate)
        return theorem_ids

    def get_dependencies(self, node_id: str) -> Set[str]:
        """Return every node from which ``node_id`` can be reached (its transitive ancestors)."""
        upstream = nx.ancestors(self, node_id)
        return set(upstream)

    def get_dependents(self, node_id: str) -> Set[str]:
        """Return every node reachable from ``node_id`` (its transitive descendants)."""
        downstream = nx.descendants(self, node_id)
        return set(downstream)

    def construction_depth(self, node_id: str) -> int:
        """Return the construction depth of ``node_id``.

        Depth is the largest, over all root nodes (in-degree 0) that can reach
        ``node_id``, of the shortest-path length from that root. A root node,
        or a node that no root reaches, has depth 0.

        Raises:
            ValueError: If ``node_id`` is not in the graph.
        """
        if node_id not in self:
            raise ValueError(f"Node {node_id} not found in graph")

        depths = []
        for candidate in self.nodes():
            if self.in_degree(candidate) != 0:
                continue  # not a root
            try:
                depths.append(nx.shortest_path_length(self, candidate, node_id))
            except nx.NetworkXNoPath:
                pass  # this root cannot reach node_id
        return max(depths, default=0)

    # Removal methods
    def remove_node(self, node_id: str) -> Tuple[List[str], List[ConstructionStep]]:
        """Delete ``node_id`` together with every node reachable from it.

        This overrides ``DiGraph.remove_node`` with cascading semantics: all
        descendants are removed first, then the node itself. Each removed
        node's ID number goes back to the pool in ``_available_ids`` so it can
        be handed out again.

        Args:
            node_id: The ID of the node to remove.

        Returns:
            Tuple[List[str], List[ConstructionStep]]: A tuple containing:
                - the IDs of all removed nodes, descendants first and
                  ``node_id`` last;
                - the construction steps recorded on those nodes, in the same
                  order (nodes without one are skipped).
        """
        # Descendants first, the requested node last.
        doomed = list(nx.descendants(self, node_id))
        doomed.append(node_id)

        removed_nodes: List[str] = []
        removed_steps: List[ConstructionStep] = []
        for victim in doomed:
            step = self.nodes[victim].get("construction_step")
            if step:
                removed_steps.append(step)
            # Clear outgoing edges, then drop the node (which also drops incoming edges).
            self.remove_edges_from(list(self.edges(victim)))
            super().remove_node(victim)
            # Release the numeric suffix for reuse by future nodes.
            prefix, number = victim.split('_')
            self._available_ids[prefix].add(int(number))
            removed_nodes.append(victim)

        return removed_nodes, removed_steps

    def remove_concept(self, concept_id: str) -> Tuple[List[str], List[ConstructionStep]]:
        """Delete a concept node together with everything that depends on it.

        Returns:
            The ``(removed_node_ids, removed_construction_steps)`` pair produced
            by ``remove_node``.

        Raises:
            ValueError: If ``concept_id`` is missing or is not a concept node.
        """
        if concept_id not in self:
            raise ValueError(f"Concept {concept_id} not found")
        if self.nodes[concept_id]["node_type"] != NodeType.CONCEPT:
            raise ValueError(f"Node {concept_id} is not a concept")
        return self.remove_node(concept_id)

    def remove_conjecture(self, conjecture_id: str) -> Tuple[List[str], List[ConstructionStep]]:
        """Delete a conjecture node together with everything that depends on it.

        Returns:
            The ``(removed_node_ids, removed_construction_steps)`` pair produced
            by ``remove_node``.

        Raises:
            ValueError: If ``conjecture_id`` is missing or is not a conjecture node.
        """
        if conjecture_id not in self:
            raise ValueError(f"Conjecture {conjecture_id} not found")
        if self.nodes[conjecture_id]["node_type"] != NodeType.CONJECTURE:
            raise ValueError(f"Node {conjecture_id} is not a conjecture")
        return self.remove_node(conjecture_id)

    def remove_theorem(self, theorem_id: str) -> Tuple[List[str], List[ConstructionStep]]:
        """Delete a theorem node together with everything that depends on it.

        Returns:
            The ``(removed_node_ids, removed_construction_steps)`` pair produced
            by ``remove_node``.

        Raises:
            ValueError: If ``theorem_id`` is missing or is not a theorem node.
        """
        if theorem_id not in self:
            raise ValueError(f"Theorem {theorem_id} not found")
        if self.nodes[theorem_id]["node_type"] != NodeType.THEOREM:
            raise ValueError(f"Node {theorem_id} is not a theorem")
        return self.remove_node(theorem_id)

    def _track_instance(self, example: Example) -> None:
        """Record an example's value in the graph-wide instance registry.

        An instance is any concrete value that has appeared in an example or
        nonexample of any entity in the graph. The registry lets those values
        later be offered to other entities that share the same value structure.

        Values are keyed by ``example.example_structure.component_types``
        only. Entities of different kinds (e.g. functions vs. predicates)
        whose values have the same shape therefore share one entry.

        Each entry is an unordered ``set``: duplicates collapse and no ordering
        is kept here. ``get_instances_by_structure`` sorts on retrieval. Whether
        the value came from an example or a nonexample is not recorded.

        Args:
            example: The example/nonexample to track. Its ``value`` must be
                hashable, since it is stored in a set.
        """
        component_types = example.example_structure.component_types
        self._instances_by_structure.setdefault(component_types, set()).add(example.value)
    
    def get_instances_by_structure(self, structure: ExampleStructure) -> List[Example]:
        """Return every tracked instance whose component types match ``structure``.

        Lookup is by ``structure.component_types`` only. Each stored value is
        wrapped as ``Example(value, structure, True)`` and the result is
        sorted, so ``Example`` objects must be orderable.

        The registry stores bare values and does not remember whether a value
        came from an example or a nonexample. Every returned object is built
        with ``True`` as its third argument (presumably the is-example flag —
        TODO confirm against ``Example``).

        Args:
            structure: The example structure to look for.

        Returns:
            A new sorted list (not a set) of Example objects; empty when
            nothing with this structure has been tracked.
        """
        values = self._instances_by_structure.get(structure.component_types, set())
        # Rewrap raw values as Example objects for callers that expect them.
        return sorted(Example(value, structure, True) for value in values)
    
    def propagate_instance(self, instance: Example) -> List[Tuple[str, str]]:
        """Propagate an instance through the knowledge graph.

        The instance's value is first recorded in the instance registry. Then
        every entity whose ``examples.example_structure`` equals the
        instance's structure is checked against ``instance.value``:

        - a conjecture that rejects the value is disproven and removed
          together with all of its dependents;
        - a concept gets the value added as an example or a nonexample,
          according to its ``verify_example`` result.

        Entities whose ``verify_example`` (or, for concepts, ``add_example`` /
        ``add_nonexample``) raises NotImplementedError are left as they are.

        Args:
            instance: The example/nonexample to propagate.

        Returns:
            List of ``(entity_id, action)`` tuples, where action is either
            ``"removed"`` or ``"updated"``. A disproven conjecture is listed
            first, followed by every dependent node removed along with it.
        """
        affected_entities: List[Tuple[str, str]] = []

        # Track the instance first
        self._track_instance(instance)

        # Iterate over a snapshot, because removing a disproven conjecture mutates the graph.
        for node_id, node in list(self.nodes(data=True)):
            # The node may already have been deleted as a dependent of a
            # conjecture removed earlier in this loop. Revisiting it would either
            # raise (remove_conjecture on a missing node) or update a detached entity.
            if node_id not in self:
                continue

            entity = node["entity"]
            if not (
                hasattr(entity, "examples")
                and entity.examples.example_structure == instance.example_structure
            ):
                continue

            # For conjectures, check if this disproves them
            if node["node_type"] == NodeType.CONJECTURE:
                try:
                    disproven = not entity.verify_example(instance.value)
                except NotImplementedError:
                    disproven = False
                if disproven:
                    removed_ids, _ = self.remove_conjecture(node_id)
                    affected_entities.append((node_id, "removed"))
                    # Report the cascade too, so callers learn about every deleted node.
                    affected_entities.extend(
                        (removed_id, "removed")
                        for removed_id in removed_ids
                        if removed_id != node_id
                    )
                    continue

            # For concepts, add to examples/nonexamples
            if node["node_type"] == NodeType.CONCEPT:
                try:
                    if entity.verify_example(instance.value):
                        entity.add_example(instance.value)
                    else:
                        entity.add_nonexample(instance.value)
                    affected_entities.append((node_id, "updated"))
                except NotImplementedError:
                    pass

        return affected_entities

    def get_all_concepts(self) -> List[str]:
        """Return the IDs of every concept node, in graph iteration order."""
        concept_ids = []
        for node_id, attrs in self.nodes(data=True):
            if attrs["node_type"] == NodeType.CONCEPT:
                concept_ids.append(node_id)
        return concept_ids

    def get_all_conjectures(self) -> List[str]:
        """Return the IDs of every conjecture node, in graph iteration order."""
        conjecture_ids = []
        for node_id, attrs in self.nodes(data=True):
            if attrs["node_type"] == NodeType.CONJECTURE:
                conjecture_ids.append(node_id)
        return conjecture_ids

    def get_all_theorems(self) -> List[str]:
        """Return the IDs of every theorem node, in graph iteration order."""
        theorem_ids = []
        for node_id, attrs in self.nodes(data=True):
            if attrs["node_type"] == NodeType.THEOREM:
                theorem_ids.append(node_id)
        return theorem_ids

    # Visualization methods
    def _wrap_label(self, text: str, max_chars: int = 30) -> str:
        """Wrap an underscore-separated label onto several lines for Graphviz.

        The text is split on ``_`` and the words are packed greedily into
        lines. Each line's underscore-joined length is at most ``max_chars``.
        A single word longer than ``max_chars`` gets a line of its own and is
        not broken.

        Args:
            text: Label text, typically a snake_case entity name.
            max_chars: Maximum length of each output line, underscores included.

        Returns:
            The lines joined by a literal backslash followed by ``n``, which
            Graphviz renders as a line break inside a label.
        """
        lines = []
        current_line = []
        current_length = 0  # always equals len('_'.join(current_line))

        for word in text.split('_'):
            # Width the current line would have once `word` is appended
            # (the joining underscore counts only when the line is non-empty).
            candidate_length = current_length + 1 + len(word) if current_line else len(word)
            if candidate_length <= max_chars or not current_line:
                current_line.append(word)
                current_length = candidate_length
            else:
                lines.append('_'.join(current_line))
                current_line = [word]
                current_length = len(word)

        if current_line:
            lines.append('_'.join(current_line))

        return '\\n'.join(lines)

    def _get_rule_name(self, rule: ProductionRule) -> str:
        """Return ``rule.name`` with one trailing ``"Rule"`` removed, for display.

        Args:
            rule: The production rule whose name will be shown.

        Returns:
            The shortened name, e.g. ``"MapRule"`` -> ``"Map"``. Names without
            the suffix are returned unchanged.
        """
        suffix = "Rule"
        raw_name = rule.name
        return raw_name[: -len(suffix)] if raw_name.endswith(suffix) else raw_name

    def visualize_construction_tree(self, output_file: str = None):
        """Visualize the construction tree with Graphviz.

        Drawing rules:
        - Every node is a rounded box labelled with its entity name.
        - Fill colour follows node type; nodes whose ``has_manual_override``
          attribute is true use a separate "ground truth" palette.
        - Nodes with no incoming edges are drawn dashed.
        - USED_IN_CONSTRUCTION edges are labelled with the production rule
          recorded on the target node.
        - Layout is left entirely to Graphviz's ``dot`` engine.

        The method is best-effort: problems are reported with ``print`` and
        never raised. Without ``output_file`` the graph is built but nothing is
        rendered or written. The scratch directory used for rendering is
        always deleted.

        Args:
            output_file: Destination path of the PNG, with or without the
                ``.png`` extension. Missing parent directories are created.
        """
        import shutil
        import tempfile

        if output_file:
            print(f"\nGenerating construction tree visualization: {output_file}")

        temp_dir = None
        try:
            # A private scratch directory per call keeps parallel renders from colliding.
            temp_dir = tempfile.mkdtemp(prefix="graph_viz_")

            colors = {
                NodeType.CONCEPT: "lightblue",
                NodeType.CONJECTURE: "lightgreen",
                NodeType.THEOREM: "lightpink",
            }
            # Palette for nodes flagged with has_manual_override (ground truth).
            override_colors = {
                NodeType.CONCEPT: "purple",
                NodeType.CONJECTURE: "pink",
                NodeType.THEOREM: "orange",
            }

            dot = Digraph(comment="Construction Tree", format="png",
                          directory=temp_dir, filename="construction_tree")
            dot.attr(rankdir="TB", nodesep="0.5", ranksep="0.5")
            self._add_construction_tree_legend(dot, colors, override_colors)

            for node_id, node_data in self.nodes(data=True):
                # Root nodes (no incoming edges) are drawn dashed.
                is_root = self.in_degree(node_id) == 0
                style = "filled,dashed" if is_root else "filled"
                palette = override_colors if node_data.get("has_manual_override", False) else colors
                dot.node(
                    node_id,
                    style=f"{style},rounded",
                    fillcolor=palette[node_data["node_type"]],
                    shape="box",  # box shape suits wrapped text
                    fontsize="10",
                    margin="0.2,0.1",
                    label=self._wrap_label(node_data["entity"].name),
                )

            for u, v, data in self.edges(data=True):
                if data.get("relation") != RelationType.USED_IN_CONSTRUCTION:
                    continue
                # The add_* methods always store a construction_step key, but its
                # value may be None (e.g. nodes linked only via add_relationship).
                step = self.nodes[v].get("construction_step")
                if step is not None:
                    dot.edge(u, v, label=self._get_rule_name(step.rule), fontsize="8")
                else:
                    dot.edge(u, v)

            if output_file:
                self._render_construction_tree_png(dot, temp_dir, output_file)
        except Exception as e:
            print(f"Error creating visualization: {e}")
        finally:
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)

    def _add_construction_tree_legend(
        self,
        dot: Digraph,
        colors: Dict[NodeType, str],
        override_colors: Dict[NodeType, str],
    ) -> None:
        """Add the legend cluster: a regular and a ground-truth swatch per node type, then a root-node swatch."""
        with dot.subgraph(name="cluster_legend") as legend:
            legend.attr(label="Legend", rankdir="TB")
            # Invisible edges chain the swatches so Graphviz stacks them vertically.
            prev_node = None
            for node_type, color in colors.items():
                curr_node = f"legend_{node_type.name}"
                legend.node(
                    curr_node,
                    node_type.name,
                    style="filled",
                    fillcolor=color,
                    width="0.75",
                    height="0.5",
                    fontsize="10",
                )
                if prev_node:
                    legend.edge(prev_node, curr_node, style="invis")

                override_node = f"legend_{node_type.name}_override"
                legend.node(
                    override_node,
                    f"{node_type.name}\n(Ground Truth)",
                    style="filled",
                    fillcolor=override_colors[node_type],
                    width="0.75",
                    height="0.5",
                    fontsize="10",
                )
                legend.edge(curr_node, override_node, style="invis")
                prev_node = override_node

            root_node = "legend_root"
            legend.node(
                root_node,
                "Root Node\n(No inputs)",
                style="filled,dashed",
                fillcolor="white",
                shape="ellipse",
                fontsize="10",
                margin="0.2,0.1"
            )
            if prev_node:
                legend.edge(prev_node, root_node, style="invis")
            legend.attr(rank="same")

    def _render_construction_tree_png(self, dot: Digraph, temp_dir: str, output_file: str) -> None:
        """Render ``dot`` to ``<output_file>.png`` with the ``dot`` CLI, using a 10-second timeout.

        Rendering failures are printed, not raised. An error while writing the
        intermediate ``.dot`` file propagates to the caller.
        """
        import shutil
        import subprocess
        from pathlib import Path

        # Strip .png extension if present to avoid double extension
        if output_file.lower().endswith('.png'):
            output_file = output_file[:-4]
        print(f"Rendering construction tree to: {output_file}")

        dot_file = os.path.join(temp_dir, "construction_tree.dot")
        png_file = os.path.join(temp_dir, "construction_tree.png")
        # Graphviz reads DOT input as UTF-8 by default.
        with open(dot_file, 'w', encoding="utf-8") as f:
            f.write(dot.source)

        try:
            # Call the CLI directly so a pathological graph cannot hang the caller.
            subprocess.run(
                ["dot", "-Tpng", dot_file, "-o", png_file],
                timeout=10,
                check=True,
                capture_output=True,
            )
            if os.path.exists(png_file):
                Path(output_file).parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(png_file, f"{output_file}.png")
                print("Construction tree visualization complete")
            else:
                print("Warning: Failed to create visualization - output file not found")
        except subprocess.TimeoutExpired:
            print("Warning: Visualization timed out after 10 seconds")
        except subprocess.CalledProcessError as e:
            print(f"Warning: Visualization command failed: {e}")
        except Exception as e:
            print(f"Warning: Error during visualization: {e}")

    def visualize_concept_map(self, root_id: str, depth: int = 2, output_file: Optional[str] = None):
        """
        Visualize a concept map centered around a specific entity.

        Nodes are collected breadth-first along outgoing edges (``self.neighbors``),
        so a node is included exactly when its shortest directed distance from
        ``root_id`` is at most ``depth``. Every edge whose two endpoints are both
        included is drawn, labelled with the name of its ``relation`` attribute.

        Args:
            root_id: ID of the central entity
            depth: Maximum number of outgoing-edge steps from the root. A negative
                value yields an empty map.
            output_file: Path to save the visualization. A trailing ``.png`` is
                stripped before rendering so the image is written to
                ``<output_file>.png`` rather than ``<output_file>.png.png``.
                When omitted, the map is built but not rendered.

        Raises:
            KeyError: if ``root_id`` is not in the graph (when ``depth >= 0``), or
                an included node / drawn edge lacks its ``entity``, ``node_type``
                or ``relation`` attribute.
        """
        if output_file:
            print(f"\nGenerating concept map visualization for {root_id}: {output_file}")

        dot = Digraph(comment=f'Concept Map for {root_id}', format='png')
        dot.attr(rankdir='TB')

        colors = {
            NodeType.CONCEPT: "lightblue",
            NodeType.CONJECTURE: "lightgreen",
            NodeType.THEOREM: "lightpink",
        }

        def draw_node(node_id: str):
            # Label with the entity name and colour by node type.
            node_data = self.nodes[node_id]
            dot.node(
                node_id,
                node_data["entity"].name,
                style="filled",
                fillcolor=colors[node_data["node_type"]],
            )

        # Breadth-first traversal: a depth-first walk sharing one `visited` set
        # can reach a node via a long path first, hit the depth limit there, and
        # then never re-expand it via a shorter path -- dropping nodes that are
        # actually within `depth` of the root.
        included: Set[str] = set()
        ordered: List[str] = []  # insertion order, for deterministic edge output
        if depth >= 0:
            draw_node(root_id)
            included.add(root_id)
            ordered.append(root_id)
            frontier = [root_id]
            for _ in range(depth):
                next_frontier = []
                for node_id in frontier:
                    for neighbor in self.neighbors(node_id):
                        if neighbor not in included:
                            draw_node(neighbor)
                            included.add(neighbor)
                            ordered.append(neighbor)
                            next_frontier.append(neighbor)
                if not next_frontier:
                    break
                frontier = next_frontier

        # Only draw edges between included nodes: an edge to a node beyond the
        # depth limit would make graphviz invent an unstyled placeholder node,
        # while edges back to already-included nodes are real relationships.
        for node_id in ordered:
            for neighbor in self.neighbors(node_id):
                if neighbor in included:
                    edge_data = self.get_edge_data(node_id, neighbor)
                    dot.edge(node_id, neighbor, label=edge_data["relation"].name)

        if output_file:
            # Digraph.render appends the format extension itself, so a
            # user-supplied ".png" suffix must be removed first.
            if output_file.lower().endswith('.png'):
                output_file = output_file[:-4]
            dot.render(output_file, view=False)

    # Persistence methods
    def _serialize_node_data(self, data):
        """Convert node data to a JSON-serializable format."""
        serialized = {}
        for key, value in data.items():
            if key == "entity":
                # Store only the entity type and name for now
                serialized[key] = {"type": value.__class__.__name__, "name": value.name}
            elif key == "node_type":
                serialized[key] = value.value
            elif key == "construction_step":
                if value is not None:
                    serialized[key] = {
                        "rule": value.rule.__class__.__name__,
                        "input_node_ids": value.input_node_ids,
                        "parameters": value.parameters,
                        "timestamp": value.timestamp.isoformat(),
                    }
                else:
                    serialized[key] = None
            else:
                serialized[key] = value
        return serialized

    def _serialize_edge_data(self, data):
        """Convert edge data to a JSON-serializable format."""
        serialized = {}
        for key, value in data.items():
            if key == "relation":
                serialized[key] = value.name
            else:
                serialized[key] = value
        return serialized

    def save(self, filepath: str):
        """Save the entire graph using dill.

        This saves the complete graph structure including all concepts and their implementations.
        Dill is used instead of pickle because it can handle local lambda functions.

        Args:
            filepath: Path where the graph should be saved. A ``.dill`` extension
                is appended if not already present. Missing parent directories
                are created; a bare filename is written to the current directory.
        """
        # Add .dill extension if not present
        if not filepath.endswith(".dill"):
            filepath = f"{filepath}.dill"

        # os.path.dirname returns "" for a bare filename, and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when one is named.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Save using dill
        with open(filepath, "wb") as f:
            dill.dump(self, f)

    @classmethod
    def load(cls, filepath: str) -> "KnowledgeGraph":
        """Load a complete graph previously written by ``save``.

        Args:
            filepath: Path of the saved graph; ``.dill`` is appended when the
                path does not already end with it.

        Returns:
            The object deserialized from the file.

        Warning:
            dill (like pickle) can execute arbitrary code while loading, so only
            load files from trusted sources.
        """
        path = filepath if filepath.endswith(".dill") else filepath + ".dill"
        with open(path, "rb") as handle:
            graph = dill.load(handle)
        return graph

    def copy(self, as_view: bool = False) -> 'KnowledgeGraph':
        """Create a copy of the knowledge graph, including counter attributes.

        The node set, edge set, graph-level attributes (``self.graph``) and every
        node/edge attribute dict are copied, so adding or removing nodes/edges or
        reassigning attributes on the copy does not affect the original. The
        attribute *values* themselves (e.g. entity objects) are shared, not
        deep-copied.

        Args:
            as_view: Accepted for compatibility with ``nx.DiGraph.copy``. When
                True, a read-only networkx view of this graph is returned instead
                of an independent copy (counter attributes are not copied then).

        Returns:
            A new ``KnowledgeGraph`` (or a graph view when ``as_view`` is True).
        """
        if as_view:
            # Mirror nx.DiGraph.copy so networkx code calling copy(as_view=True)
            # on a KnowledgeGraph keeps working.
            return nx.graphviews.generic_graph_view(self)

        new_graph = KnowledgeGraph()

        # Carry over graph-level attributes, as nx.DiGraph.copy does.
        new_graph.graph.update(self.graph)

        # Copy nodes and edges one by one; add_node/add_edge build fresh
        # attribute dicts from the unpacked keyword arguments.
        for node, data in self.nodes(data=True):
            new_graph.add_node(node, **data)

        for u, v, data in self.edges(data=True):
            new_graph.add_edge(u, v, **data)

        # Copy counter attributes
        new_graph._concept_counter = self._concept_counter
        new_graph._conjecture_counter = self._conjecture_counter
        new_graph._theorem_counter = self._theorem_counter

        return new_graph

    def _node_ids_of_type(self, node_type: NodeType) -> List[str]:
        """Return the IDs of all nodes whose ``node_type`` attribute equals *node_type*.

        Nodes lacking a ``node_type`` attribute (e.g. nodes created implicitly by
        ``add_edge``) are skipped rather than raising ``KeyError``. IDs are
        returned in the graph's node iteration order.
        """
        return [
            node_id
            for node_id, data in self.nodes(data=True)
            if data.get("node_type") == node_type
        ]

    def get_all_concepts(self) -> List[str]:
        """Get all concept nodes in the graph."""
        return self._node_ids_of_type(NodeType.CONCEPT)

    def get_all_conjectures(self) -> List[str]:
        """Get all conjecture nodes in the graph."""
        return self._node_ids_of_type(NodeType.CONJECTURE)

    def get_all_theorems(self) -> List[str]:
        """Get all theorem nodes in the graph."""
        return self._node_ids_of_type(NodeType.THEOREM)
