"""
Tests for the interestingness function loading mechanism in InterestingnessGuidedPolicy.
"""

import os
import sys
import unittest
import tempfile
from pathlib import Path

# Make the project root importable so that `frame` resolves to this checkout.
# resolve() makes the path absolute: on Python < 3.9 ``__file__`` may be
# relative when this file is run as a script, and the parents of a relative
# path collapse to "." too early. Inserting at the front makes this checkout
# take precedence over any other copy of ``frame`` already on sys.path.
project_root = Path(__file__).resolve().parent.parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from frame.policies.interestingness_guided_policy import (
    InterestingnessGuidedPolicy,
    ConceptSelectionStrategy,
    ActionSelectionStrategy
)
from frame.knowledge_base.knowledge_graph import KnowledgeGraph, NodeType
from frame.knowledge_base.entities import (
    Concept, 
    Conjecture, 
    ExampleStructure, 
    ConceptType,
    ExampleType,
    Zero
)

class TestInterestingnessLoading(unittest.TestCase):
    """Test suite for the interestingness function loading mechanism.

    Covers both ways ``InterestingnessGuidedPolicy`` receives a scorer:
    a callable passed as ``interestingness_scorer`` and a Python file
    loaded via ``interestingness_function_path``.
    """

    def setUp(self):
        """Build a graph with one concept and one conjecture; locate fixtures."""
        self.graph = KnowledgeGraph()

        # A simple constant concept (similar to zero_concept from map_iterate.py).
        self.concept = Concept(
            name="TestConcept",
            description="A test concept",
            symbolic_definition=lambda: Zero(),
            computational_implementation=lambda: 0,
            example_structure=ExampleStructure(
                concept_type=ConceptType.CONSTANT,
                component_types=(ExampleType.NUMERIC,),
                input_arity=0,
            ),
        )
        self.concept_id = self.graph.add_concept(self.concept)

        # A simple conjecture with placeholder definition and implementation.
        self.conjecture = Conjecture(
            name="TestConjecture",
            description="A test conjecture",
            symbolic_definition=lambda: Zero(),
            computational_implementation=lambda: True,
            example_structure=ExampleStructure(
                concept_type=ConceptType.CONSTANT,
                component_types=(ExampleType.NUMERIC,),
                input_arity=0,
            ),
        )
        self.conjecture_id = self.graph.add_conjecture(self.conjecture)

        # Fixture programs live in ``test_programs/`` next to this file.
        self.test_dir = Path(__file__).resolve().parent
        self.test_programs_dir = self.test_dir / "test_programs"

    def _make_policy(self, **policy_kwargs):
        """Return a policy using the selection strategies shared by every test.

        Args:
            **policy_kwargs: Extra keyword arguments forwarded to
                ``InterestingnessGuidedPolicy`` (the scorer or its file path).
        """
        return InterestingnessGuidedPolicy(
            concept_selection=ConceptSelectionStrategy.INTERESTINGNESS,
            action_selection=ActionSelectionStrategy.SIMULATE_AND_SCORE,
            **policy_kwargs,
        )

    def _fixture_program(self, filename, label):
        """Return the path (as ``str``) of a fixture program, failing if missing.

        Args:
            filename: File name inside ``self.test_programs_dir``.
            label: Human-readable prefix used in the failure message.
        """
        program_path = self.test_programs_dir / filename
        self.assertTrue(
            program_path.is_file(),
            f"{label} test program file does not exist at {program_path}",
        )
        return str(program_path)

    def test_direct_scorer_assignment(self):
        """Test that direct assignment of interestingness_scorer works."""

        def test_scorer(entity_id, graph):
            # get_node returns a 3-tuple; only the node type is needed here.
            _, node_type, _ = graph.get_node(entity_id)
            return 1.0 if node_type == NodeType.CONCEPT else 0.5

        policy = self._make_policy(interestingness_scorer=test_scorer)

        concept_score = policy.interestingness_scorer(self.concept_id, self.graph)
        conjecture_score = policy.interestingness_scorer(self.conjecture_id, self.graph)

        self.assertEqual(concept_score, 1.0)
        self.assertEqual(conjecture_score, 0.5)

    def test_missing_function_name(self):
        """Test behavior when the loaded file has no calculate_interestingness."""
        # A temporary directory is removed on exit even if writing the program
        # or constructing the policy fails, so no stray file is left behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            program_path = Path(tmp_dir) / "wrong_name_program.py"
            program_path.write_text(
                "def wrong_function_name(entity_id, graph):\n"
                "    return 1.0\n",
                encoding="utf-8",
            )

            with self.assertRaises(AttributeError) as context:
                self._make_policy(interestingness_function_path=str(program_path))

        # The error message should mention the missing function and what was found.
        message = str(context.exception)
        self.assertIn("calculate_interestingness", message)
        self.assertIn("wrong_function_name", message)

    def test_nonexistent_file(self):
        """Test behavior with nonexistent file."""
        nonexistent_path = "/path/to/nonexistent/file.py"

        with self.assertRaises(FileNotFoundError) as context:
            self._make_policy(interestingness_function_path=nonexistent_path)

        # The error message should name the path that could not be found.
        self.assertIn(nonexistent_path, str(context.exception))

    def test_import_complex_test_program(self):
        """Test importing and using the complex interestingness function."""
        policy = self._make_policy(
            interestingness_function_path=self._fixture_program(
                "complex_test.py", "Complex"
            )
        )

        # Any exception here propagates and unittest reports its traceback.
        concept_score = policy.interestingness_scorer(self.concept_id, self.graph)
        conjecture_score = policy.interestingness_scorer(self.conjecture_id, self.graph)

        # complex_test.py might return 0.0 due to dependency errors, so only
        # check that the loaded function returns a value, not what it is.
        self.assertIsNotNone(concept_score, "complex_test.py returned None for the concept")
        self.assertIsNotNone(
            conjecture_score, "complex_test.py returned None for the conjecture"
        )

    def test_import_simple_test_program(self):
        """Test importing and using the simple interestingness function without DSL dependencies."""
        policy = self._make_policy(
            interestingness_function_path=self._fixture_program(
                "simple_test.py", "Simple"
            )
        )

        concept_score = policy.interestingness_scorer(self.concept_id, self.graph)
        conjecture_score = policy.interestingness_scorer(self.conjecture_id, self.graph)

        # Scores are floats; compare approximately so arithmetic rounding in the
        # fixture (e.g. 0.3 + 0.15 != 0.45) does not cause spurious failures.
        self.assertAlmostEqual(concept_score, 0.75)
        self.assertAlmostEqual(conjecture_score, 0.45)

if __name__ == "__main__":
    # Support direct execution of this file; "__main__" is unittest's default module.
    unittest.main(module="__main__")