"""Specification file for FunSearch interestingness functions."""

import re
from typing import Dict, Any, Optional, List, Union

class InterestingnessSpec:
    """Specification for the interestingness function used with FunSearch.

    Every method is static and works with *source text*: the methods return
    (or inspect) strings of Python code for a function named
    ``calculate_interestingness``; none of them executes that code.
    """

    # Compiled once at class creation. The pattern tolerates arbitrary
    # whitespace between tokens but requires the exact parameter names,
    # annotations, default and return annotation of the expected signature.
    _SIGNATURE_PATTERN = re.compile(
        r"def\s+calculate_interestingness\s*\(\s*knowledge_graph\s*,\s*entity_name\s*:\s*str\s*,\s*relation_name\s*:\s*str\s*,\s*target_entity_name\s*:\s*str\s*,\s*context\s*:\s*Dict\[\s*str\s*,\s*Any\s*\]\s*=\s*None\s*\)\s*->\s*float\s*:"
    )

    @staticmethod
    def get_function_signature() -> str:
        """
        Returns the function signature for the interestingness function.

        Returns:
            The ``def ...:`` line as a single string, without a trailing newline
        """
        return (
            "def calculate_interestingness(knowledge_graph, entity_name: str, relation_name: str, "
            "target_entity_name: str, context: Dict[str, Any] = None) -> float:"
        )

    @staticmethod
    def get_function_header() -> str:
        """
        Returns the full header (signature + docstring) for the interestingness function.

        Returns:
            The signature line followed by the indented docstring of the
            generated function, ending with a newline
        """
        signature = InterestingnessSpec.get_function_signature()
        # The literal starts with a newline so the docstring lands on the line
        # directly below the signature.
        docstring = '''
    """Calculate the interestingness score for a potential relation between entities.
    
    Args:
        knowledge_graph: The knowledge graph (KnowledgeGraph object)
        entity_name: The source entity name
        relation_name: The relation name
        target_entity_name: The target entity name
        context: Optional context with additional information
        
    Returns:
        A float value representing the interestingness score (higher is more interesting)
    """
'''
        return signature + docstring

    @staticmethod
    def get_default_function() -> str:
        """
        Returns a default implementation of the interestingness function.

        The result is the header from ``get_function_header`` followed by a
        baseline body. It does not include the import lines from
        ``get_imports``.

        Returns:
            A string containing the full function implementation
        """
        header = InterestingnessSpec.get_function_header()
        implementation = """
    # Simple implementation - can be replaced by FunSearch
    # Higher values indicate more interesting relations
    
    # Default score
    score = 0.0
    
    # If context is not provided, initialize empty dictionary
    if context is None:
        context = {}
    
    # Get information about the entities and relation
    source_entity = knowledge_graph.get_entity(entity_name)
    target_entity = knowledge_graph.get_entity(target_entity_name)
    
    if source_entity is None or target_entity is None:
        return 0.0  # One or both entities don't exist
    
    # 1. Prefer relations that create connections between disconnected components
    source_neighbors = set(knowledge_graph.get_entity_neighbors(entity_name))
    if target_entity_name not in source_neighbors:
        score += 1.0
    
    # 2. Prefer entities with fewer existing relations (to balance the graph)
    source_relations_count = len(source_neighbors)
    target_relations_count = len(knowledge_graph.get_entity_neighbors(target_entity_name))
    
    # Normalize by the total number of entities
    total_entities = len(knowledge_graph.get_all_entity_names())
    if total_entities > 1:
        sparsity_factor = 1.0 - (source_relations_count + target_relations_count) / (2 * (total_entities - 1))
        score += sparsity_factor
    
    # 3. Consider the entity types
    if hasattr(source_entity, 'type') and hasattr(target_entity, 'type'):
        if source_entity.type != target_entity.type:
            # Encourage connections between different types
            score += 0.5
    
    # 4. Consider relation frequency - prefer less common relations
    relation_count = knowledge_graph.count_relation_occurrences(relation_name)
    total_relations = knowledge_graph.count_total_relations()
    
    if total_relations > 0:
        rarity_score = 1.0 - (relation_count / total_relations)
        score += rarity_score
    
    return score
"""
        return header + implementation

    @staticmethod
    def get_imports() -> str:
        """
        Returns the import statements needed for the interestingness function.

        Returns:
            A string containing the import statements, with a leading and a
            trailing newline
        """
        return """
import math
from typing import Dict, Any, List, Set
"""

    @staticmethod
    def validate_function(function_code: str) -> bool:
        """
        Validates that the function code adheres to the expected signature.

        Only the signature is checked: the code must contain a
        ``def calculate_interestingness(...) -> float:`` line with the expected
        parameter names, annotations and default. Whitespace between tokens is
        not significant. The function body is not inspected and the code is
        never executed.

        Args:
            function_code: The Python code string to validate

        Returns:
            True if the function code is valid, False otherwise (including
            when ``function_code`` is not a string)
        """
        # Reject non-string input explicitly rather than failing with a
        # TypeError inside the search.
        if not isinstance(function_code, str):
            return False

        # The regex is the single source of truth. A separate substring
        # pre-check would reject whitespace variants that the pattern is
        # written to accept.
        return InterestingnessSpec._SIGNATURE_PATTERN.search(function_code) is not None

    @staticmethod
    def create_full_function_code() -> str:
        """
        Creates a complete function code with imports, function signature, and implementation.

        Returns:
            The output of ``get_imports`` followed by the output of
            ``get_default_function``
        """
        imports = InterestingnessSpec.get_imports()
        function_code = InterestingnessSpec.get_default_function()

        return imports + function_code

# Source text of the default calculate_interestingness implementation, exposed
# at module level for easier importing. NOTE: this is a string of Python code
# (header + body, without the lines from get_imports()), not a callable function.
calculate_interestingness: str = InterestingnessSpec.get_default_function()