"""DSL primitives for learning interestingness functions.

This module defines the basic functions (primitives) that can be combined to 
create more complex interestingness functions. These primitives operate on 
the knowledge graph and specific entity_ids.
"""

import math
from collections import deque
from typing import Any, List, Dict, Tuple, Set, Optional, Union

import numpy as np

from frame.knowledge_base.knowledge_graph import KnowledgeGraph, NodeType
from frame.knowledge_base.entities import ConceptType, ExampleType, Entity, Theorem, Example

# --- Graph Structure Primitives ---

def get_ancestors(entity_id: str, graph: KnowledgeGraph) -> List[str]:
    """Return the ancestors of an entity as a list.

    Thin wrapper around ``graph.get_dependencies``; the result is
    materialized with ``list`` so callers always receive a list.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        The items produced by ``graph.get_dependencies(entity_id)``.

    Raises:
        ValueError: If the lookup fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        return list(graph.get_dependencies(entity_id))
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting ancestors for entity {entity_id}") from exc

def get_descendants(entity_id: str, graph: KnowledgeGraph) -> List[str]:
    """Return the descendants of an entity as a list.

    Thin wrapper around ``graph.get_dependents``; the result is
    materialized with ``list`` so callers always receive a list.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        The items produced by ``graph.get_dependents(entity_id)``.

    Raises:
        ValueError: If the lookup fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        return list(graph.get_dependents(entity_id))
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting descendants for entity {entity_id}") from exc

def get_direct_descendants(entity_id: str, graph: KnowledgeGraph) -> List[str]:
    """Return the direct descendants (successors) of an entity as a list.

    Thin wrapper around ``graph.successors``.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        The items produced by ``graph.successors(entity_id)``.

    Raises:
        ValueError: If the lookup fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        return list(graph.successors(entity_id))
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting direct_descendants for entity {entity_id}") from exc
    
def get_direct_ancestors(entity_id: str, graph: KnowledgeGraph) -> List[str]:
    """Return the direct ancestors (predecessors) of an entity as a list.

    Thin wrapper around ``graph.predecessors``.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        The items produced by ``graph.predecessors(entity_id)``.

    Raises:
        ValueError: If the lookup fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        return list(graph.predecessors(entity_id))
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting direct_ancestors for entity {entity_id}") from exc

def get_construction_depth(entity_id: str, graph: KnowledgeGraph) -> int:
    """Return the construction depth of an entity.

    Delegates to ``graph.construction_depth``.
    NOTE(review): presumably the longest path from a root node, as the
    original docstring stated; confirm against KnowledgeGraph.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        Whatever ``graph.construction_depth(entity_id)`` returns.

    Raises:
        ValueError: If the lookup fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        return graph.construction_depth(entity_id)
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting construction_depth for entity {entity_id}") from exc

def get_in_degree(entity_id: str, graph: KnowledgeGraph) -> int:
    """Return the in-degree of an entity (number of direct parents).

    ``graph.in_degree(entity_id)`` may return a plain ``int`` or a
    dict-like degree view. Both are handled, mirroring ``get_out_degree``.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        The in-degree as reported by the graph. For a dict-like result,
        the entry for ``entity_id`` is used, or 0 if it is absent.

    Raises:
        ValueError: If the graph call fails or returns an unsupported
            type. The underlying exception is attached as ``__cause__``.
    """
    try:
        degree_info = graph.in_degree(entity_id)
        if isinstance(degree_info, int):
            return degree_info
        if hasattr(degree_info, 'get'):
            # Dict-like view: pick out the degree for this node. Previously
            # this branch raised before reaching its own return statement.
            return degree_info.get(entity_id, 0)
        raise ValueError(f"Unexpected type for in_degree: {type(degree_info)}")
    except Exception as e:
        raise ValueError(f"Error getting in_degree for entity {entity_id}: Error is {e}") from e

def get_out_degree(entity_id: str, graph: KnowledgeGraph) -> int:
    """Return the out-degree of an entity (number of direct children).

    ``graph.out_degree(entity_id)`` may return a plain ``int`` or a
    dict-like degree view. Both are handled.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        The out-degree as reported by the graph. For a dict-like result,
        the entry for ``entity_id`` is used, or 0 if it is absent.

    Raises:
        ValueError: If the graph call fails or returns an unsupported
            type. The underlying exception is attached as ``__cause__``.
    """
    try:
        degree_info = graph.out_degree(entity_id)
        if isinstance(degree_info, int):
            return degree_info
        if hasattr(degree_info, 'get'):
            # Dict-like view: pick out the degree for this node.
            return degree_info.get(entity_id, 0)
        # Unsupported result types are reported rather than mapped to 0.
        raise ValueError(f"Unexpected type for out_degree: {type(degree_info)}")
    except Exception as e:
        raise ValueError(f"Error getting out_degree for entity {entity_id}: Error is {e}") from e

def get_construction_history_rule_names(entity_id: str, graph: KnowledgeGraph) -> List[str]:
    """Return rule names along one construction path leading to this entity.

    The search walks breadth-first from ``entity_id`` through each node's
    ``construction_step.input_node_ids``. Every node is visited at most
    once, so each root (a node with no construction inputs) is reached by
    a shortest path. Among those root-to-entity paths the longest one is
    chosen, and the rule name of every node on it that has a construction
    step is returned.

    Note: This assumes an acyclic construction graph. Because nodes are not
    revisited, inputs reachable only through longer alternative routes may
    be missing from the reported history.
    Potential usability issue: FunSearch typically prefers numerical/boolean
    features. Consider returning numerical summaries (e.g., rule counts) in
    the future.

    Args:
        entity_id: Identifier of the node whose history is requested.
        graph: Knowledge graph; ``graph.nodes[node_id]`` must be dict-like.

    Returns:
        Rule names ordered from the earliest ancestor to the rule that
        created ``entity_id``. Empty if no root path is found, for example
        when the entity is unknown.
    """
    # Each queue item is (node_id, path from node_id down to entity_id).
    queue = deque([(entity_id, [entity_id])])
    visited = {entity_id}
    root_paths = []

    while queue:
        current_id, path = queue.popleft()
        try:
            step = graph.nodes[current_id].get('construction_step')
            parent_ids = step.input_node_ids if step else None
            if parent_ids:
                for parent_id in parent_ids:
                    if parent_id not in visited:
                        visited.add(parent_id)
                        queue.append((parent_id, [parent_id] + path))
            else:
                # No construction inputs: this node is a root of the history.
                root_paths.append(path)
        except Exception:
            # Best effort: an unknown or malformed node ends this branch.
            continue

    if not root_paths:
        return []

    # Prefer the path with the most steps; ties keep the first one found.
    longest_path_nodes = max(root_paths, key=len)

    rule_names = []
    for node_id in longest_path_nodes:
        try:
            step = graph.nodes[node_id].get('construction_step')
            if step:
                rule_names.append(step.rule.name)
        except KeyError:
            continue

    return rule_names

def get_entity_step_age(entity_id: str, graph: KnowledgeGraph) -> int:
    """Return the age of an entity in construction steps.

    Delegates to ``graph.get_entity_step_age``.
    NOTE(review): per the original description, age is the difference
    between the graph's current step counter and the step at which the
    entity was created; confirm against KnowledgeGraph.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph to query.

    Returns:
        Whatever ``graph.get_entity_step_age(entity_id)`` returns.

    Raises:
        ValueError: If the lookup fails for any reason (this function does
            not fall back to 0). The underlying exception is attached as
            ``__cause__``.
    """
    try:
        return graph.get_entity_step_age(entity_id)
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting entity_step_age for entity {entity_id}") from exc

def get_num_concepts(entity_id = None, graph: KnowledgeGraph = None) -> int:
    """Return one plus the number of concept nodes in the graph.

    The count always covers the whole graph; ``entity_id`` is not used to
    restrict it to a subgraph.
    NOTE(review): the +1 offset looks deliberate (it keeps the result
    non-zero); confirm before changing.

    Args:
        entity_id: Ignored for counting. If ``graph`` is omitted, this
            argument is treated as the graph, which tolerates callers that
            pass the graph as the only positional argument.
        graph: The knowledge graph to analyze.

    Returns:
        ``1 + (number of nodes whose 'node_type' is NodeType.CONCEPT)``.

    Raises:
        ValueError: If neither argument is provided.
    """
    if graph is None and entity_id is None:
        raise ValueError("Either entity_id or graph must be provided")
    if graph is None:
        # Arguments were passed the wrong way round; recover the graph.
        graph = entity_id

    concept_total = sum(
        1
        for node in graph.nodes
        if graph.nodes[node]['node_type'] == NodeType.CONCEPT
    )
    return 1 + concept_total

def get_num_conjectures(entity_id = None, graph: KnowledgeGraph = None) -> int:
    """Return one plus the number of conjecture nodes in the graph.

    The count always covers the whole graph; ``entity_id`` is not used to
    restrict it to a subgraph.
    NOTE(review): the +1 offset looks deliberate (it keeps the result
    non-zero); confirm before changing.

    Args:
        entity_id: Ignored for counting. If ``graph`` is omitted, this
            argument is treated as the graph, which tolerates callers that
            pass the graph as the only positional argument.
        graph: The knowledge graph to analyze.

    Returns:
        ``1 + (number of nodes whose 'node_type' is NodeType.CONJECTURE)``.

    Raises:
        ValueError: If neither argument is provided.
    """
    if graph is None and entity_id is None:
        raise ValueError("Either entity_id or graph must be provided")
    if graph is None:
        # Arguments were passed the wrong way round; recover the graph.
        graph = entity_id

    conjecture_total = sum(
        1
        for node in graph.nodes
        if graph.nodes[node]['node_type'] == NodeType.CONJECTURE
    )
    return 1 + conjecture_total

# --- Entity Attribute Primitives ---

# Note(_; 3/31): This is semi-placeholder for now, unclear that we want a method for this at present.
def get_entity_node_type(entity_id: str, graph: KnowledgeGraph) -> str:
    """Return the node type of an entity as a label string.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]['node_type']`` is read.

    Returns:
        ``"Concept"``, ``"Conjecture"`` or ``"Theorem"`` for the matching
        ``NodeType`` member, otherwise ``"Unknown"``. The value is a
        string, not a numeric code.

    Raises:
        ValueError: If the node or its type cannot be read. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        node_type = graph.nodes[entity_id]['node_type']
        if node_type == NodeType.CONCEPT:
            return "Concept"
        if node_type == NodeType.CONJECTURE:
            return "Conjecture"
        if node_type == NodeType.THEOREM:
            return "Theorem"
        return "Unknown"
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting entity_node_type for entity {entity_id}") from exc

def get_concept_category(entity_id: str, graph: KnowledgeGraph) -> str:
    """Return the concept category of an entity as a label string.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]['entity']`` is read.

    Returns:
        ``"Predicate"``, ``"Function"``, ``"Relation"`` or ``"Constant"``
        according to ``entity.example_structure.concept_type``. Returns
        ``"Unknown"`` when the entity has no (truthy) ``example_structure``
        or the type matches none of these. The value is a string, not a
        numeric code.

    Raises:
        ValueError: If the entity cannot be read. The underlying exception
            is attached as ``__cause__``.
    """
    try:
        entity = graph.nodes[entity_id]['entity']
        structure = getattr(entity, 'example_structure', None)
        if structure:
            concept_type = structure.concept_type
            if concept_type == ConceptType.PREDICATE:
                return "Predicate"
            if concept_type == ConceptType.FUNCTION:
                return "Function"
            if concept_type == ConceptType.RELATION:
                return "Relation"
            if concept_type == ConceptType.CONSTANT:
                return "Constant"
        return "Unknown"
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting concept_category for entity {entity_id}") from exc

def get_input_arity(entity_id: str, graph: KnowledgeGraph) -> int:
    """Return the input arity of an entity's example structure.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]['entity']`` is read.

    Returns:
        ``entity.example_structure.input_arity`` when the entity has a
        truthy ``example_structure`` and the arity is not ``None``;
        otherwise 0.

    Raises:
        ValueError: If the entity cannot be read. The underlying exception
            is attached as ``__cause__``.
    """
    try:
        entity = graph.nodes[entity_id]['entity']
        structure = getattr(entity, 'example_structure', None)
        if structure and structure.input_arity is not None:
            return structure.input_arity
        return 0
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting input_arity for entity {entity_id}") from exc

def get_num_component_types(entity_id: str, graph: KnowledgeGraph) -> int:
    """Return the number of component types in an entity's example structure.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]['entity']`` is read.

    Returns:
        ``len(entity.example_structure.component_types)`` when the entity
        has a truthy ``example_structure``; otherwise 0.

    Raises:
        ValueError: If the entity cannot be read. The underlying exception
            is attached as ``__cause__``.
    """
    try:
        entity = graph.nodes[entity_id]['entity']
        structure = getattr(entity, 'example_structure', None)
        if structure:
            return len(structure.component_types)
        return 0
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting num_component_types for entity {entity_id}") from exc

# --- Example/Non-Example Primitives ---

def get_examples(entity_id: str, graph: KnowledgeGraph) -> List[Any]:
    """Return the values of an entity's positive examples.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]['entity']`` is read.

    Returns:
        The ``.value`` of every item from ``entity.examples.get_examples()``
        (the payloads, not the example objects themselves). Empty if the
        entity has no ``examples`` attribute.

    Raises:
        ValueError: If the entity or its examples cannot be read. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        entity = graph.nodes[entity_id]['entity']
        if hasattr(entity, 'examples'):
            return [example.value for example in entity.examples.get_examples()]
        return []
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting examples for entity {entity_id}") from exc

def get_nonexamples(entity_id: str, graph: KnowledgeGraph) -> List[Any]:
    """Return the values of an entity's negative examples.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]['entity']`` is read.

    Returns:
        The ``.value`` of every item from
        ``entity.examples.get_nonexamples()`` (the payloads, not the example
        objects themselves). Empty if the entity has no ``examples``
        attribute.

    Raises:
        ValueError: If the entity or its nonexamples cannot be read. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        entity = graph.nodes[entity_id]['entity']
        if hasattr(entity, 'examples'):
            return [example.value for example in entity.examples.get_nonexamples()]
        return []
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting nonexamples for entity {entity_id}") from exc

# --- Construction Step Primitives ---

def get_num_construction_inputs(entity_id: str, graph: KnowledgeGraph) -> int:
    """Return how many direct inputs were used to construct an entity.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]`` must be dict-like.

    Returns:
        ``len(step.input_node_ids)`` for the node's ``'construction_step'``,
        or 0 when the node has no (truthy) construction step.

    Raises:
        ValueError: If the node or its step cannot be read. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        step = graph.nodes[entity_id].get('construction_step')
        return len(step.input_node_ids) if step else 0
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting num_construction_inputs for entity {entity_id}") from exc

# --- Conjecture/Theorem Primitives ---

def is_proven(entity_id: str, graph: KnowledgeGraph) -> float:
    """Return 1.0 if the entity's node type is THEOREM, else 0.0.

    Args:
        entity_id: Identifier of the node to query.
        graph: Knowledge graph; ``graph.nodes[entity_id]['node_type']`` is read.

    Returns:
        1.0 when the node type equals ``NodeType.THEOREM``, otherwise 0.0.

    Raises:
        ValueError: If the node or its type cannot be read, including when
            the entity is not in the graph. The underlying exception is
            attached as ``__cause__``.
    """
    try:
        node_data = graph.nodes[entity_id]
        return 1.0 if node_data['node_type'] == NodeType.THEOREM else 0.0
    except Exception as exc:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error getting is_proven for entity {entity_id}") from exc

# --- General Primitives ---

def create_weighted_interestingness_function(
    functions: List[callable], 
    weights: List[float]
) -> callable:
    """
    Build a scorer that blends several interestingness functions.

    Args:
        functions: Interestingness functions, each called as
            ``func(entity_id, graph)``.
        weights: One weight per function, in the same order.

    Returns:
        A callable ``(entity_id, graph) -> float`` that returns the weighted
        sum of the individual scores divided by the sum of all weights.
        A function that raises contributes nothing to the sum, but its
        weight still counts in the divisor. The result is 0.0 when the
        weights do not sum to a positive number.

    Raises:
        ValueError: If the lengths of functions and weights don't match
    """
    if len(weights) != len(functions):
        raise ValueError("Number of functions must match number of weights")

    def weighted_interestingness(entity_id: str, graph: KnowledgeGraph) -> float:
        """Score an entity with the configured functions and weights."""
        weight_sum = sum(weights)
        accumulated = 0.0

        for position, scorer in enumerate(functions):
            try:
                component = scorer(entity_id, graph)
                accumulated += weights[position] * component
            except Exception:
                # Best effort: a failing component is simply left out.
                continue

        # Normalize by the total weight, guarding against a zero divisor.
        return accumulated / weight_sum if weight_sum > 0 else 0.0

    return weighted_interestingness

# Registry of the primitives that learned programs may call.
# NOTE(review): get_direct_descendants and get_direct_ancestors are defined in
# this module but are not listed here; confirm whether that is intentional.
ALL_PRIMITIVES = [
    # Graph Structure Primitives
    get_ancestors,
    get_descendants,
    get_construction_depth,
    get_in_degree,
    get_out_degree,
    get_construction_history_rule_names,
    get_entity_step_age,
    get_num_concepts,
    get_num_conjectures,
    # Entity Attribute Primitives
    get_entity_node_type,
    get_concept_category,
    get_input_arity,
    get_num_component_types,
    # Example/Non-Example Primitives
    get_examples,
    get_nonexamples,
    # Construction Step Primitives
    get_num_construction_inputs,
    # Conjecture/Theorem Primitives
    is_proven,
    # General Primitives
    create_weighted_interestingness_function,
]

# --- HR Interestingness Functions ---

# (Concepts)
# Note(_; 4/9): Conjectural Comprehensibility is the same as Concept Comprehensibility.
def recreate_comprehensibility(entity_id: str, graph: KnowledgeGraph) -> float:
    """HR Comprehensibility, rebuilt from DSL primitives.

    Args:
        entity_id: Identifier of the entity to score.
        graph: Knowledge graph containing the entity.

    Returns:
        ``1 / (1 + k)``, where ``k`` is the number of distinct ancestors of
        the entity whose node type is CONCEPT. Returns 0.0 if the entity is
        not in the graph or is not itself a concept.
    """
    # Unknown entities get the default score.
    if entity_id not in graph.nodes:
        return 0.0

    # Only concepts are scored by this measure.
    if graph.nodes[entity_id]['node_type'] != NodeType.CONCEPT:
        return 0.0

    concept_ancestors = set()
    for ancestor in get_ancestors(entity_id, graph):
        if graph.nodes[ancestor]['node_type'] == NodeType.CONCEPT:
            concept_ancestors.add(ancestor)

    # The +1 keeps the divisor non-zero for concepts without ancestors.
    return 1.0 / (1.0 + len(concept_ancestors))

def recreate_parsimony(entity_id: str, graph: KnowledgeGraph) -> float:
    """HR Parsimony, rebuilt from DSL primitives.

    Args:
        entity_id: Identifier of the entity to score.
        graph: Knowledge graph containing the entity.

    Returns:
        ``1 / (1 + k)``, where ``k`` is the value reported by
        ``get_num_component_types``. Returns 0.0 if the entity is not in the
        graph or is not a concept.
    """
    # Unknown entities get the default score.
    if entity_id not in graph.nodes:
        return 0.0

    # Only concepts are scored by this measure.
    if graph.nodes[entity_id]['node_type'] != NodeType.CONCEPT:
        return 0.0

    # Fewer component types means a more parsimonious concept; the +1
    # keeps the divisor non-zero.
    component_count = get_num_component_types(entity_id, graph)
    return 1.0 / (1.0 + component_count)

def recreate_applicability(entity_id: str, graph: KnowledgeGraph) -> float:
    """HR Applicability, rebuilt from DSL primitives.

    Args:
        entity_id: Identifier of the entity to score.
        graph: Knowledge graph containing the entity.

    Returns:
        ``E / (E + N + 1)``, where ``E`` and ``N`` are the numbers of
        examples and nonexamples of the entity. Returns 0.0 if the entity
        is not in the graph or is not a concept.
    """
    # Unknown entities get the default score.
    if entity_id not in graph.nodes:
        return 0.0

    # Only concepts are scored by this measure.
    if graph.nodes[entity_id]['node_type'] != NodeType.CONCEPT:
        return 0.0

    positive_count = len(get_examples(entity_id, graph))
    negative_count = len(get_nonexamples(entity_id, graph))
    # The +1 keeps the divisor non-zero when there are no instances at all.
    return float(positive_count) / (positive_count + negative_count + 1.0)

def recreate_novelty(entity_id: str, graph: KnowledgeGraph) -> float:
    """Recreate the HR Novelty measure using DSL primitives.

    Computes the reciprocal of the number of concepts (including the
    concept itself) which share the same example categorization as the
    entity. Two concepts share a categorization when their example lists
    and their nonexample lists both compare equal.

    Previously the comparison ran only for the entity itself, so no other
    concept was ever compared and the score ignored duplicates.

    Args:
        entity_id: Identifier of the entity to score.
        graph: Knowledge graph containing the entity.

    Returns:
        ``1 / m``, where ``m`` is the number of concept nodes with the same
        examples and nonexamples as the entity (``m >= 1``). Returns 0.0 if
        the entity is not in the graph or is not a concept.
    """
    # Return default score if entity doesn't exist
    if entity_id not in graph.nodes:
        return 0.0

    if graph.nodes[entity_id]['node_type'] != NodeType.CONCEPT:
        return 0.0

    entity_examples = get_examples(entity_id, graph)
    entity_nonexamples = get_nonexamples(entity_id, graph)

    matching = 0
    for concept_id in graph.nodes:
        if graph.nodes[concept_id]['node_type'] != NodeType.CONCEPT:
            continue
        if concept_id == entity_id:
            # The entity trivially shares its own categorization.
            matching += 1
            continue
        # Note(_; 4/9): A relaxation of the HR requirement but we are holding concepts differently.
        if (get_examples(concept_id, graph) == entity_examples
                and get_nonexamples(concept_id, graph) == entity_nonexamples):
            matching += 1

    if matching == 0:
        # Defensive: only reachable if iterating graph.nodes skips the entity.
        return 0.0
    return 1.0 / matching

# Note(_; 4/9): This might be changed to outdegree, wording in HR is unclear.
def recreate_productivity(entity_id: str, graph: KnowledgeGraph) -> float:
    """HR Productivity, rebuilt from DSL primitives.

    Args:
        entity_id: Identifier of the entity to score.
        graph: Knowledge graph containing the entity.

    Returns:
        ``D / (A + 1)``, where ``D`` is the number of descendants of the
        entity and ``A`` is its step age from ``get_entity_step_age``.
        Returns 0.0 if the entity is not in the graph or is not a concept.
    """
    # Unknown entities get the default score.
    if entity_id not in graph.nodes:
        return 0.0

    # Only concepts are scored by this measure.
    if graph.nodes[entity_id]['node_type'] != NodeType.CONCEPT:
        return 0.0

    descendant_count = len(get_descendants(entity_id, graph))
    step_age = get_entity_step_age(entity_id, graph)
    # The +1 keeps the divisor non-zero for an entity of age zero.
    return float(descendant_count) / (step_age + 1.0)

def recreate_num_conjectures_appearing(entity_id: str, graph: KnowledgeGraph) -> float:
    """Return the number of conjectures built directly from the entity.

    A descendant counts when its node type is CONJECTURE and ``entity_id``
    is among its construction inputs. Inputs are read from the node's
    ``'input_node_ids'`` attribute when present, otherwise from its
    ``'construction_step'`` (the place the other primitives in this module
    read ``input_node_ids`` from).

    Args:
        entity_id: Identifier of the entity to score.
        graph: Knowledge graph containing the entity.

    Returns:
        The count as a float. Returns 0.0 if the entity is not in the graph
        or is not a concept.
    """
    if entity_id not in graph.nodes:
        return 0.0

    if graph.nodes[entity_id]['node_type'] != NodeType.CONCEPT:
        return 0.0

    count = 0
    for descendant in get_descendants(entity_id, graph):
        node_data = graph.nodes[descendant]
        if node_data['node_type'] != NodeType.CONJECTURE:
            continue
        input_ids = node_data.get('input_node_ids')
        if input_ids is None:
            # Fall back to the construction step recorded on the node.
            step = node_data.get('construction_step')
            input_ids = getattr(step, 'input_node_ids', None) or []
        # Direct descendants only: the entity must be an immediate input.
        if entity_id in input_ids:
            count += 1
    return float(count)

# (Conjectures)
def recreate_surprisingness(entity_id: str, graph: KnowledgeGraph) -> float:
    """Recreate the HR Surprisingness measure using DSL primitives.

    For equivalence and implication conjectures, counts the nodes that
    appear in exactly one of the two parents' ancestor sets (the symmetric
    difference of the sets returned by ``get_ancestors``).
    NOTE(review): ancestors are not filtered by node type here, unlike
    ``recreate_comprehensibility``; confirm whether only concepts should
    count.

    Args:
        entity_id: Identifier of the conjecture to score.
        graph: Knowledge graph containing the entity.

    Returns:
        The size of the symmetric difference as a float. Returns 0.0 if the
        entity is not in the graph, is not a conjecture, was not built by
        an "equivalence" or "implication" rule, or does not have exactly two
        direct ancestors.
    """
    if entity_id not in graph.nodes:
        return 0.0

    node_data = graph.nodes[entity_id]
    if node_data['node_type'] != NodeType.CONJECTURE:
        return 0.0

    rule_name = node_data['entity'].construction_step.rule.name
    if rule_name not in ("equivalence", "implication"):
        return 0.0

    parents = get_direct_ancestors(entity_id, graph)
    if len(parents) != 2:
        # Avoid crashing on the two-way unpack. With a single parent both
        # histories coincide, so there is nothing surprising to count.
        return 0.0

    first_history = set(get_ancestors(parents[0], graph))
    second_history = set(get_ancestors(parents[1], graph))
    # Nodes that appear in only one of the two construction histories.
    return float(len(first_history ^ second_history))

def recreate_conjectural_applicability(entity_id: str, graph: KnowledgeGraph) -> float:
    """HR Conjectural Applicability, rebuilt from DSL primitives.

    For "equivalence" and "implication" conjectures the score is the
    applicability (``recreate_applicability``) of the conjecture's first
    direct ancestor.

    Args:
        entity_id: Identifier of the conjecture to score.
        graph: Knowledge graph containing the entity.

    Returns:
        The first parent's applicability for equivalence/implication
        conjectures. Returns 0.0 if the entity is not in the graph, is not
        a conjecture, or was built by any other rule.

    Raises:
        ValueError: If the conjecture has no direct ancestors.
    """
    if entity_id not in graph.nodes:
        return 0.0

    node_data = graph.nodes[entity_id]
    if node_data['node_type'] != NodeType.CONJECTURE:
        return 0.0

    rule_name = node_data['entity'].construction_step.rule.name
    parents = get_direct_ancestors(entity_id, graph)
    if not parents:
        raise ValueError(f"Conjecture {entity_id} has no parents")

    # Both supported conjecture kinds are scored through the first parent.
    if rule_name in ("equivalence", "implication"):
        return recreate_applicability(parents[0], graph)
    return 0.0
    
# TODO(_; 4/9): Eventually we can add a proof length measure for theorems.

# --- Example Interestingness Functions ---

# TODO(_; 4/9): Make configurable.
# One weight per HR measure, in the same order as the functions below.
# The weights need not sum to 1: the combined function divides by their total.
HR_WEIGHTS = [
    0.17,  # comprehensibility
    0.17,  # parsimony
    0.17,  # applicability
    0.17,  # novelty
    0.17,  # productivity
    0.17,  # conjectural applicability
]
HR_INTERESTINGNESS_FUNCTION = create_weighted_interestingness_function(
    [
        recreate_comprehensibility,
        recreate_parsimony,
        recreate_applicability,
        recreate_novelty,
        recreate_productivity,
        recreate_conjectural_applicability,
    ],
    HR_WEIGHTS,
)

def example_complex_interestingness(entity_id: str, graph: KnowledgeGraph) -> float:
    """
    Example of a more complex interestingness function that combines multiple factors:
    - Comprehensibility (``recreate_comprehensibility``)
    - Parsimony (``recreate_parsimony``)
    - Example density (``recreate_applicability``)
    - Concept utility (natural log of 1 + number of descendants)
    - Proven status (``is_proven``)

    This is just an example of what a learned function might look like.

    Args:
        entity_id: Identifier of the entity to score.
        graph: Knowledge graph containing the entity.

    Returns:
        ``0.2 * comprehensibility + 0.1 * parsimony + 0.2 * example_density
        + 0.3 * utility + 0.2 * proven_status``.

    Raises:
        ValueError: Propagated from ``is_proven`` or ``get_descendants``
            when the entity cannot be read from the graph.
    """
    # Recreate basic measures
    comprehensibility = recreate_comprehensibility(entity_id, graph)
    parsimony = recreate_parsimony(entity_id, graph)

    # Add entity-specific factors
    example_density = recreate_applicability(entity_id, graph)
    proven_status = is_proven(entity_id, graph)

    # Concepts with more descendants are more useful building blocks.
    # The argument to log is always >= 1, so no error handling is needed.
    num_descendants = len(get_descendants(entity_id, graph))
    utility = math.log(1.0 + num_descendants)

    # Combine the factors with weights
    return (0.2 * comprehensibility +
            0.1 * parsimony +
            0.2 * example_density +
            0.3 * utility +
            0.2 * proven_status)

# Examples of how the DSL can be used to create interestingness functions.
# Maps a short name to a callable taking (entity_id, graph).
# NOTE(review): recreate_surprisingness and recreate_conjectural_applicability
# are defined in this module but are not registered here; confirm whether
# that is intentional.
EXAMPLE_LEARNED_FUNCTIONS = {
    "comprehensibility": recreate_comprehensibility,
    "parsimony": recreate_parsimony,
    "applicability": recreate_applicability,
    "novelty": recreate_novelty,
    "productivity": recreate_productivity,
    "num_conjectures_appearing": recreate_num_conjectures_appearing,
    "complex_example": example_complex_interestingness,
    "hr": HR_INTERESTINGNESS_FUNCTION,
}
