"""
Storage for ground truth mathematical entities in the FRAME system.

This module maintains a mapping between discovered mathematical concepts/conjectures
and their canonical names and implementations. When a policy discovers a concept
that matches a known mathematical entity, we can:
1. Rename it to its canonical name
2. Optionally update its computational implementation to match the ground truth
3. Track which names can be discovered for this entity

The structure is designed to minimize redundancy by:
1. Using a single canonical name for each entity
2. Tracking which names can be discovered for each entity
3. Providing a way to normalize names before comparison
"""

from typing import Dict, Set, Optional, Callable, Any, Tuple, List
from dataclasses import dataclass
from enum import Enum
import re
import logging

from frame.knowledge_base.entities import Example
from frame.environments.ground_truth_types import GroundTruthEntity, EntityType
from frame.environments.NT.operations import NT_FUNCTIONS, NT_PREDICATES, NT_NUMBERS
from frame.environments.NT.calculations import NT_CALCULATIONS, NT_PREDICATE_CALCULATIONS, NT_EQ_CALCULATIONS
from frame.environments.NT.conjectures import NT_CONJECTURES, NT_PRED_ALWAYS_TRUE

# ANSI escape sequences used to highlight log messages on a terminal.
COLORS = dict(
    CYAN='\033[36m',   # switch foreground to cyan
    RESET='\033[0m',   # restore default terminal attributes
)

def _strip_enclosing_parens(text: str) -> str:
    """Remove parenthesis pairs that wrap the *whole* of ``text``.

    Only a '(' at position 0 whose matching ')' is the final character is
    removed (repeatedly), so "((a))" becomes "a" while "(a)_(b)" and
    "iterate_(x)" are returned unchanged.
    """
    while text.startswith("(") and text.endswith(")"):
        depth = 0
        closes_at = -1
        for idx, char in enumerate(text):
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    closes_at = idx
                    break
        # The leading '(' must close on the last character to enclose everything.
        if closes_at != len(text) - 1:
            break
        text = text[1:-1]
    return text


def extract_concept_parts(name: str) -> List[str]:
    """
    Extract the parts of a concept name.

    The name is split at the first underscore that follows a letter; the
    remainder has its enclosing parentheses removed. Only *matched* enclosing
    pairs are removed, so nested applications keep balanced parentheses.

    For example:
    - "iterate_(add_with_zero)" -> ["iterate", "add_with_zero"]
    - "iterate_(iterate_(successor)_with_zero)" -> ["iterate", "iterate_(successor)_with_zero"]
    - "iterate_(iterate_(successor))" -> ["iterate", "iterate_(successor)"]

    Args:
        name: The concept name to split.

    Returns:
        A one-element list when the name has no letter-followed underscore,
        otherwise ``[head, argument]``.
    """
    # Remove parentheses that wrap the entire name, if any
    name = _strip_enclosing_parens(name)

    # Split on first underscore after a letter
    parts = re.split(r"(?<=[a-zA-Z])_", name, maxsplit=1)
    if len(parts) == 1:
        return parts

    head, argument = parts
    # The argument is usually written "(...)"; drop only that enclosing pair,
    # not every trailing ')' (str.strip would unbalance nested applications).
    return [head, _strip_enclosing_parens(argument)]

def substitute_concept_name(name: str, name_mapping: Dict[str, str]) -> str:
    """
    Substitute known concept names in a concept application.

    Applications of the form ``RULE_(argument)`` are resolved from the
    innermost parentheses outwards. Each resolved application is looked up in
    ``name_mapping`` and, if present, replaced by its mapped name before the
    enclosing application is looked up. Sibling applications at the same
    nesting level (e.g. ``f_(a)_with_(b)``) are each resolved in turn.

    Args:
        name: The concept name to substitute in
        name_mapping: Dictionary mapping discovered names to canonical names

    Returns:
        The name with substitutions applied. A name whose parentheses are
        unbalanced is returned unchanged, unless the whole name is a key of
        ``name_mapping``.
    """
    # First check if the entire name is a known concept
    if name in name_mapping:
        return name_mapping[name]

    # Check if there are any applied production rules to parse
    if '(' not in name:
        return name

    # Unprocessed parentheses are written as '<' / '>' so that text which has
    # already been resolved (and is re-emitted with '(' / ')') is not parsed
    # a second time.
    expr = name.replace("(", "<").replace(")", ">")

    # Unbalanced input cannot be split into rule applications; leave it alone
    # instead of producing a garbled name.
    depth = 0
    for char in expr:
        if char == '<':
            depth += 1
        elif char == '>':
            depth -= 1
            if depth < 0:
                return name
    if depth != 0:
        return name

    call_count = 0

    def _revert(fragment):
        # Convert unprocessed '<' / '>' back to ordinary parentheses.
        return fragment.replace("<", "(").replace(">", ")")

    def _matching_close(expr, open_idx):
        # Index of the '>' that closes the '<' at open_idx (-1 if none).
        level = 0
        for idx in range(open_idx, len(expr)):
            if expr[idx] == '<':
                level += 1
            elif expr[idx] == '>':
                level -= 1
                if level == 0:
                    return idx
        return -1

    # Process the base case - a string with exactly one pair of brackets
    def _process_base_case(expr):
        start_paren_idx = expr.find('<')
        end_paren_idx = expr.find('>')

        before_paren = expr[:start_paren_idx]
        inside_paren = expr[start_paren_idx+1:end_paren_idx]

        # The only mappings that can happen are formatted: RULE_(stuff_in_parenthesis).
        # The rule is taken to be the last underscore-terminated word before the
        # bracket (or everything before it when there is at most one underscore).
        # NOTE(review): a multi-word rule such as "size_of_" is therefore cut to
        # "of_" here; only the whole-name lookup above catches such names.
        underscore_positions = [m.start() for m in re.finditer('_', before_paren)]
        if len(underscore_positions) <= 1:
            rule_idx = 0
        else:
            rule_idx = underscore_positions[-2]+1
        rule = before_paren[rule_idx:]
        innermost = rule + "(" + inside_paren + ")"  # revert "<>" to "()"
        proc_innermost = name_mapping.get(innermost, innermost)

        return expr[:rule_idx] + proc_innermost + expr[end_paren_idx+1:]

    def _process_level(expr, name_mapping):
        # expr is always a balanced slice of the encoded name.
        # Guard against excessive recursion
        nonlocal call_count
        call_count += 1
        if call_count > 50:
            print(f"WARNING: Recursion limit reached in substitute_concept_name for: {name}")
            return _revert(expr)

        open_idx = expr.find('<')
        if open_idx == -1:
            # No unprocessed brackets left at this level.
            return expr
        if expr.count('<') == 1:
            return _process_base_case(expr)

        # Pair the first '<' with ITS matching '>', not with the last '>' in
        # the string: the two differ when sibling groups are present.
        close_idx = _matching_close(expr, open_idx)
        if close_idx == -1:
            return _revert(expr)
        inner_content = expr[open_idx+1:close_idx]

        if '<' in inner_content:
            # Resolve the nested application first, then look up the
            # enclosing one with the resolved argument substituted in.
            processed_inner = _process_level(inner_content, name_mapping)
            new_expr = expr[:open_idx] + "(" + processed_inner + ")"
            head = name_mapping.get(new_expr, new_expr)
        else:
            # The first group is a leaf followed by further sibling groups.
            head = _process_base_case(expr[:close_idx+1])

        # Any sibling groups after the first one are resolved as well.
        tail = expr[close_idx+1:]
        if '<' in tail:
            tail = _process_level(tail, name_mapping)
        return head + tail

    return _process_level(expr, name_mapping)

def strip_asterisks(name: str) -> str:
    """
    Drop the run of '*' characters at the end of ``name``.

    Asterisks elsewhere in the name are kept.

    Args:
        name: The name to clean up

    Returns:
        ``name`` truncated before its trailing asterisks (unchanged if none)
    """
    end = len(name)
    while end > 0 and name[end - 1] == '*':
        end -= 1
    return name[:end]

def normalize_name(name: str) -> str:
    """
    Normalize a concept name by recursively substituting known concept names.

    Trailing asterisks are removed first. The substitution table maps every
    ``discovered_names`` entry of every ground truth entity to that entity's
    ``canonical_name``; it is rebuilt on each call. Nested applications are
    then rewritten by ``substitute_concept_name``.

    For example, given:
    - "iterate_(successor)" -> "add"
    - "iterate_(add_with_zero)" -> "multiply"

    then "iterate_(iterate_(successor)_with_zero)*" becomes
    "iterate_(add_with_zero)" and finally "multiply".

    Args:
        name: The concept name to normalize

    Returns:
        The normalized name with all known substitutions applied
    """
    cleaned = strip_asterisks(name)

    # If two entities list the same discovered name, the one iterated later wins.
    name_mapping = {
        discovered_name: entity.canonical_name
        for entity in GROUND_TRUTH_ENTITIES.values()
        for discovered_name in entity.discovered_names
    }

    return substitute_concept_name(cleaned, name_mapping)

# Every known ground truth entity. get_ground_truth_entity looks keys up with
# the output of normalize_name. The source tables are merged in the order
# below, so when two tables share a key the later table's entity wins.
GROUND_TRUTH_ENTITIES: Dict[str, GroundTruthEntity] = dict(NT_FUNCTIONS)
GROUND_TRUTH_ENTITIES.update(NT_PREDICATES)
GROUND_TRUTH_ENTITIES.update(NT_NUMBERS)
GROUND_TRUTH_ENTITIES.update(NT_CALCULATIONS)
GROUND_TRUTH_ENTITIES.update(NT_EQ_CALCULATIONS)
GROUND_TRUTH_ENTITIES.update(NT_PREDICATE_CALCULATIONS)
GROUND_TRUTH_ENTITIES.update(NT_CONJECTURES)
GROUND_TRUTH_ENTITIES.update(NT_PRED_ALWAYS_TRUE)

def get_ground_truth_entity(name: str) -> Optional[GroundTruthEntity]:
    """
    Look up the ground truth entity that ``name`` refers to.

    The name is first passed through ``normalize_name``, which strips trailing
    asterisks and rewrites known sub-concepts to their canonical names. The
    result is then used as a key into ``GROUND_TRUTH_ENTITIES``.

    For example, if "iterate_(successor)" is known to be "add", then
    "iterate_(iterate_(successor)_with_zero)" is rewritten to
    "iterate_(add_with_zero)". That is a discovered name of "multiply", so it
    normalizes to "multiply" and the multiply entity is returned.

    Args:
        name: The concept name to look up

    Returns:
        The ground truth entity if found, None otherwise
    """
    return GROUND_TRUTH_ENTITIES.get(normalize_name(name))

def is_ground_truth_entity(name: str) -> bool:
    """Return True when ``name`` normalizes to a key of ``GROUND_TRUTH_ENTITIES``."""
    normalized = normalize_name(name)
    return normalized in GROUND_TRUTH_ENTITIES

def get_canonical_name(name: str) -> Optional[str]:
    """Return the canonical name of the ground truth entity matching ``name``, or None."""
    entity = get_ground_truth_entity(name)
    if not entity:
        return None
    return entity.canonical_name

def update_entity_implementation(entity: Any, name: str, update_implementation: bool = True) -> bool:
    """
    Rename ``entity`` to its canonical ground-truth name and optionally adopt
    the ground truth's implementation data.

    When ``update_implementation`` is true, and the ground truth provides the
    corresponding data, this also:
    - installs its computational implementation (with both example and
      nonexample addition enabled) and purges examples it now rejects;
    - adds its ``new_examples`` / ``new_nonexamples`` with ``override=True``;
    - installs its Z3 translation.

    Args:
        entity: The entity to update (mutated in place)
        name: The current name of the entity, used for the ground-truth lookup
        update_implementation: Whether to update the computational implementation

    Returns:
        bool: True if a ground truth entity matched (and ``entity`` was
        renamed), False otherwise
    """
    match = get_ground_truth_entity(name)
    if not match:
        return False

    log = logging.getLogger(__name__)

    entity.name = match.canonical_name

    if update_implementation:
        implementation = match.computational_implementation
        if implementation:
            # Ground truth implementations are trusted for both directions.
            entity.set_computational_implementation(
                implementation,
                can_add_examples=True,
                can_add_nonexamples=True,
            )

            # presumably (removed_examples, removed_nonexamples) — TODO confirm
            purged = entity.remove_invalid_examples()
            # TODO(_;4/21): Could reassign examples and nonexamples
            if purged[0] or purged[1]:
                print(f"Removed examples: {', '.join(str(f.value) for f in purged[0])}")
                print(f"Removed nonexamples: {', '.join(str(f.value) for f in purged[1])}")

        # Note(_; 5/8): only for tau function
        if match.new_examples or match.new_nonexamples:
            for example in match.new_examples:
                entity.add_example(example, override=True)
            for nonexample in match.new_nonexamples:
                entity.add_nonexample(nonexample, override=True)

    if match.z3_translation and update_implementation:
        entity.set_z3_translation(match.z3_translation)

    # Highlight notable discoveries in the log.
    canonical = match.canonical_name
    if match.entity_type == EntityType.CONJECTURE:
        log.info(f"{COLORS['CYAN']}FOUND GROUND TRUTH CONJECTURE {canonical}{COLORS['RESET']}")
    if canonical in ("tau", "is_prime"):
        log.info(f"{COLORS['CYAN']}FOUND IMPORTANT GROUND TRUTH ENTITY {canonical}{COLORS['RESET']}")

    return True

if __name__ == "__main__":
    import unittest

    # `dataclass` and `Callable` come from the module-level imports.

    class TestGroundTruthEntities(unittest.TestCase):
        def setUp(self):
            # Minimal stand-in for a knowledge-base concept. It implements every
            # method that update_entity_implementation may call on an entity, so
            # the update_implementation=True path does not raise AttributeError.
            @dataclass
            class TestConcept:
                name: str

                def set_computational_implementation(self, impl: Callable, can_add_examples: bool,
                                                     can_add_nonexamples: bool):
                    self.computational_implementation = impl
                    self.can_add_examples = can_add_examples
                    self.can_add_nonexamples = can_add_nonexamples

                def remove_invalid_examples(self):
                    # The stub stores no examples, so nothing is ever purged.
                    return [], []

                def add_example(self, example, override: bool = False):
                    self.__dict__.setdefault("added_examples", []).append(example)

                def add_nonexample(self, example, override: bool = False):
                    self.__dict__.setdefault("added_nonexamples", []).append(example)

                def set_z3_translation(self, translation):
                    self.z3_translation = translation

            self.TestConcept = TestConcept

        def test_normalize_name(self):
            """Test the name normalization function with recursive concept substitution."""
            test_cases = [
                # Basic canonical names
                ("add", "add"),
                ("multiply", "multiply"),
                ("power", "power"),

                # Direct discovered names
                ("iterate_(successor)", "add"),
                ("iterate_(add_with_zero)", "multiply"),
                ("iterate_(multiply_with_one)", "power"),

                # Recursive substitutions
                ("iterate_(iterate_(successor)_with_zero)", "multiply"),
                ("iterate_(iterate_(iterate_(successor)_with_zero)_with_one)", "power"),

                # Non-recursive specializations
                ("specialized_(divides_at_0_to_two)", "is_even"),
                # ("exists_(multiply)", "divides"),
                ("exists_(multiply_indices_[0])", "divides"),
                ("size_of_(divides_indices_[0])", "tau"),

                # Constants
                # TODO(_; 4/16): Constant rule in REPL
                # ("constant_zero", "zero"),
                # ("constant_successor", "one"),
                ("specialized_(successor_at_0_to_zero)", "one"),

                # Non-existent names should remain unchanged
                ("nonexistent", "nonexistent"),
                ("", ""),
                (" ", " "),
            ]

            for input_name, expected in test_cases:
                with self.subTest(input_name=input_name):
                    self.assertEqual(normalize_name(input_name), expected)

        def test_get_ground_truth_entity(self):
            """Test retrieving ground truth entities."""
            # Test direct canonical names
            self.assertIsNotNone(get_ground_truth_entity("add"))
            self.assertIsNotNone(get_ground_truth_entity("multiply"))
            self.assertIsNotNone(get_ground_truth_entity("power"))

            # Test discovered names
            self.assertIsNotNone(get_ground_truth_entity("iterate_(successor)"))
            self.assertIsNotNone(get_ground_truth_entity("iterate_(add_with_zero)"))
            self.assertIsNotNone(get_ground_truth_entity("iterate_(multiply_with_one)"))

            # Test non-existent names
            self.assertIsNone(get_ground_truth_entity("nonexistent"))
            self.assertIsNone(get_ground_truth_entity(""))
            self.assertIsNone(get_ground_truth_entity(" "))

        def test_is_ground_truth_entity(self):
            """Test checking if a name is a ground truth entity."""
            # Test canonical names
            self.assertTrue(is_ground_truth_entity("add"))
            self.assertTrue(is_ground_truth_entity("multiply"))
            self.assertTrue(is_ground_truth_entity("power"))

            # Test discovered names
            self.assertTrue(is_ground_truth_entity("iterate_(successor)"))
            self.assertTrue(is_ground_truth_entity("iterate_(add_with_zero)"))
            self.assertTrue(is_ground_truth_entity("iterate_(multiply_with_one)"))

            # Test non-existent names
            self.assertFalse(is_ground_truth_entity("nonexistent"))
            self.assertFalse(is_ground_truth_entity(""))
            self.assertFalse(is_ground_truth_entity(" "))

        def test_get_canonical_name(self):
            """Test getting canonical names for entities."""
            # Test canonical names
            self.assertEqual(get_canonical_name("add"), "add")
            self.assertEqual(get_canonical_name("multiply"), "multiply")
            self.assertEqual(get_canonical_name("power"), "power")

            # Test discovered names
            self.assertEqual(get_canonical_name("iterate_(successor)"), "add")
            self.assertEqual(get_canonical_name("iterate_(add_with_zero)"), "multiply")
            self.assertEqual(get_canonical_name("iterate_(multiply_with_one)"), "power")

            # Test non-existent names
            self.assertIsNone(get_canonical_name("nonexistent"))
            self.assertIsNone(get_canonical_name(""))
            self.assertIsNone(get_canonical_name(" "))

        def test_update_entity_implementation(self):
            """Test updating entity implementations."""
            # Create a test concept
            concept = self.TestConcept(name="iterate_(successor)")

            # Test updating with implementation
            self.assertTrue(update_entity_implementation(concept, "iterate_(successor)", update_implementation=True))
            self.assertEqual(concept.name, "add")  # Should be renamed to canonical name
            self.assertIsNotNone(concept.computational_implementation)  # Should have implementation

            # Test updating without implementation
            concept = self.TestConcept(name="iterate_(successor)")
            self.assertTrue(update_entity_implementation(concept, "iterate_(successor)", update_implementation=False))
            self.assertEqual(concept.name, "add")  # Should be renamed to canonical name
            self.assertFalse(hasattr(concept, 'computational_implementation'))  # Should not have implementation

            # Test non-existent entity
            concept = self.TestConcept(name="nonexistent")
            self.assertFalse(update_entity_implementation(concept, "nonexistent", update_implementation=True))
            self.assertEqual(concept.name, "nonexistent")  # Should not be renamed
            self.assertFalse(hasattr(concept, 'computational_implementation'))  # Should not have implementation

        def test_recursive_concept_application(self):
            """Test handling of recursive concept applications."""
            # Test complex cases that should match ground truth entities
            test_cases = [
                ("iterate_(add_with_zero)", "multiply"),
                ("iterate_(iterate_(successor)_with_zero)", "multiply"),
                ("iterate_(iterate_(iterate_(successor)_with_zero)_with_one)", "power"),
                ("specialized_(leq_than_at_0_to_specialized_(successor_at_0_to_zero))", "geq_one"),
                # TODO(_;4/17)
                # ("specialized_(exists_(add_indices_[0])_at_0_to_specialized_(successor_at_0_to_zero))", "geq_one")
            ]

            for input_name, expected in test_cases:
                with self.subTest(input_name=input_name):
                    self.assertEqual(get_canonical_name(input_name), expected)

    # Run the tests
    unittest.main(verbosity=2)