"""
This module implements the Implication production rule for generating conjectures about concept implications.
"""

from typing import List, Optional, Set, Tuple, Any, Dict, Type
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    Expression, Var, ConceptApplication, Forall, NatDomain,
    Entity, Concept, Conjecture, Example, ExampleType, ExampleStructure, ConceptType,
    Equals, Implies, And, Exists, Nat, Fold, Lambda
)
import logging
from frame.tools.z3_template import Z3Template, _format_args

class ImplicationRule(ProductionRule):
    """
    Production rule that takes two predicate concepts and produces an implication conjecture.
    
    For predicates P, Q: A₁ × ... × Aₙ → Bool, produces:
    ∀x₁...xₙ. P(x₁,...,xₙ) → Q(x₁,...,xₙ)
    
    Example:
    - Given is_even_square(n) and is_divisible_by_4(n)
    - Produces conjecture that every even square number is divisible by 4
    """
    
    def __init__(self):
        """Register this rule under the name ``implication``, producing Conjectures."""
        super().__init__(
            name="implication",
            description="Creates a conjecture stating that one predicate concept implies another",
            type="Conjecture"
        )
    
    def determine_verification_capabilities(self, *inputs: Entity) -> Tuple[bool, bool]:
        """
        Determine verification capabilities for implication conjectures.
        
        For implication conjectures (P → Q), verification capabilities depend on the constituent concepts:
        - For examples: We can verify an example of implication (P → Q is true) if we can verify:
          * P is false (which makes P → Q trivially true), or
          * Both P and Q are true (which makes P → Q true)
        - For nonexamples: We can verify a case where P → Q is false only if we can verify
          that P is true and Q is false.
        
        Args:
            *inputs: Expected to be exactly two concepts (P, Q); any other
                count yields (False, False).
        
        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples)
        """
        if len(inputs) != 2:
            return False, False
            
        concept1, concept2 = inputs  # P → Q
        
        # For implication examples, we need to verify either:
        # - P is false (which needs concept1's nonexample verification), or
        # - Both P and Q are true (which needs both concepts' example verification)
        concept1_can_add_examples = concept1.can_add_examples
        concept2_can_add_examples = concept2.can_add_examples
        concept1_can_add_nonexamples = concept1.can_add_nonexamples
        
        can_verify_examples = concept1_can_add_nonexamples or (
            concept1_can_add_examples and concept2_can_add_examples
        )
        
        # For implication nonexamples (P → Q is false), we need to verify both:
        # - P is true (which needs concept1's example verification), and
        # - Q is false (which needs concept2's nonexample verification)
        concept2_can_add_nonexamples = concept2.can_add_nonexamples
        
        can_verify_nonexamples = concept1_can_add_examples and concept2_can_add_nonexamples
        
        return can_verify_examples, can_verify_nonexamples
    
    def get_input_types(self) -> List[List[Tuple[Type, Any]]]:
        """
        Returns the valid input types for this production rule.
        
        Returns:
            A list of lists, where each inner list represents a valid combination of input types.
            Each element in the inner list is a tuple of (type, subtype).
        """
        return [
            [(Concept, ConceptType.PREDICATE), (Concept, ConceptType.PREDICATE)]
        ]
    
    def get_valid_parameterizations(self, *inputs: Entity) -> List[Dict[str, Any]]:
        """
        Get valid parameterizations for creating an implication conjecture.
        
        Args:
            *inputs: Two predicate concepts to create an implication between
            
        Returns:
            ``[{}]`` (a single empty parameter dict, since the rule takes no
            extra parameters) when ``can_apply`` accepts the inputs, else ``[]``.
        """
        # Check if the rule can be applied to these inputs
        if not self.can_apply(*inputs, verbose=False):
            return []
            
        # Implication rule doesn't require any additional parameters
        return [{}]
    
    def can_apply(self, *inputs: Entity, verbose: bool = True) -> bool:
        """
        Check if this rule can be applied to the inputs.
        
        Requirements (checked in this order):
        1. Exactly two inputs, both Concepts
        2. The two concepts must be different (compared by name)
        3. Both must be predicates
        4. Must have the same input arity and matching component types
        5. No recorded example of the first concept is a recorded nonexample
           of the second (the data must not contradict P → Q)
        6. At least one input is a recorded example of both concepts, so the
           implication is not vacuous on the recorded data
        
        A non-equivalence requirement (some input where Q holds but P does
        not) is NOT currently enforced.
        
        Args:
            *inputs: Input entities to check
            verbose: Whether to print debug information. The vacuous-implication
                warning is also only logged when this is True.
        
        Returns:
            True if every requirement holds, False otherwise.
        """
        if verbose:
            print("Checking requirements:")
        
        if len(inputs) != 2 or not all(isinstance(x, Concept) for x in inputs):
            if verbose:
                print("❌ Failed: Must have exactly two inputs of type Concept")
            return False
            
        concept1, concept2 = inputs  # P → Q
        
        # Check that the concepts are not the same
        if concept1.name == concept2.name:
            if verbose:
                print("❌ Failed: Cannot create implication between the same concept")
            return False
            
        if verbose:
            print(f"✓ Inputs are concepts: {concept1.name}, {concept2.name}")
        
        structure1 = concept1.examples.example_structure
        structure2 = concept2.examples.example_structure
        
        # Check that both are predicates
        if (structure1.concept_type != ConceptType.PREDICATE or
                structure2.concept_type != ConceptType.PREDICATE):
            if verbose:
                print("❌ Failed: Both concepts must be predicates")
            return False
            
        # Check arities match
        if structure1.input_arity != structure2.input_arity:
            if verbose:
                print("❌ Failed: Concepts must have the same arity")
            return False
            
        # Check types match
        if structure1.component_types != structure2.component_types:
            if verbose:
                print("❌ Failed: Concepts must have matching types")
            return False
        
        # Arities are equal at this point, so one slice length serves both
        # concepts; slicing keeps only the input components of each example.
        arity = structure1.input_arity
        examples1 = {ex.value[:arity] for ex in concept1.examples.get_examples()}
        examples2 = {ex.value[:arity] for ex in concept2.examples.get_examples()}
        nonexamples2 = {ex.value[:arity] for ex in concept2.examples.get_nonexamples()}
        
        # Check that P → Q is satisfied (no contradictions)
        if not examples1.isdisjoint(nonexamples2):
            if verbose:
                print("❌ Failed: First concept's examples contradict second concept")
            return False
        
        # Check for non-trivial implication (must have P true → Q true cases)
        if examples1.isdisjoint(examples2):
            if verbose:
                print("❌ Failed: No examples where both concepts are true")
                logging.warning(f"Implication from {concept1.name} to {concept2.name} would be vacuously true")
            return False
        
        # NOTE: a non-equivalence check (reject the pair unless some example of
        # concept2 is a nonexample of concept1, suggesting EquivalenceRule
        # instead) previously lived here and is currently disabled.
                
        if verbose:
            print("✓ Valid concepts for implication")
        return True

    def apply(self, *inputs: Entity) -> Entity:
        """
        Apply the implication rule to create a new conjecture.
        
        For predicates P, Q: A₁ × ... × Aₙ → Bool:
        ∀x₁...xₙ. P(x₁,...,xₙ) → Q(x₁,...,xₙ)
        
        Returns:
            A Conjecture sharing concept1's example structure. It has a Z3
            translation only when both input concepts have one.
        
        Raises:
            ValueError: if ``can_apply`` rejects the inputs.
        """
        if not self.can_apply(*inputs, verbose=False):
            raise ValueError("Cannot apply Implication to these inputs")
            
        concept1, concept2 = inputs  # P → Q
        arity = concept1.examples.example_structure.input_arity
        
        # Determine verification capabilities
        can_add_examples, can_add_nonexamples = self.determine_verification_capabilities(*inputs)
        
        # Bound variables x0..x{arity-1}; named `variables` so the builtin
        # vars() is not shadowed.
        variables = [Var(f"x{i}") for i in range(arity)]
        
        # Build the implication expression
        impl_expr = Implies(
            ConceptApplication(concept1, *variables),
            ConceptApplication(concept2, *variables)
        )
        
        # Wrap innermost-first so that x0 ends up as the outermost quantifier.
        # NOTE(review): every variable ranges over NatDomain regardless of the
        # concepts' component_types — presumably only numeric predicates reach
        # this rule; confirm.
        expr = impl_expr
        for i in reversed(range(arity)):
            expr = Forall(f"x{i}", NatDomain(), expr)
        
        # Create the conjecture with the same example structure as the input concepts
        conjecture = Conjecture(
            name=f"implies_({concept1.name}_{concept2.name})",
            description=f"Conjecture that {concept1.name} implies {concept2.name}",
            symbolic_definition=lambda: expr,
            example_structure=concept1.examples.example_structure,
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
            z3_translation=(lambda *args: self._z3_translate_implication(concept1, concept2, *args))
            if concept1.has_z3_translation() and concept2.has_z3_translation()
            else None
        )
        conjecture.map_iterate_depth = max(concept1.map_iterate_depth, concept2.map_iterate_depth)
        return conjecture

    def _z3_translate_implication(self, concept1: Concept, concept2: Concept, *args) -> Z3Template:
        """
        Translate an implication P -> Q to a Z3 program.
        
        Both concepts are translated with all-``None`` arguments and embedded
        as predicates ``p_0`` (P) and ``p_1`` (Q). The program returns
        ``ForAll([b_0, ..., b_{n-1}], Implies(p_0(...), p_1(...)))`` over
        ``arity`` bounded parameters.
        
        Args:
            concept1: The antecedent P.
            concept2: The consequent Q; assumed to share concept1's arity.
            *args: Forwarded to ``Z3Template.set_args``.
        
        Returns:
            The assembled Z3Template.
        
        Raises:
            Whatever translating either concept raises. The error is logged
            with the implication's concept names and then re-raised.
        """
        try:
            arity = concept1.get_input_arity() # Arity should be the same for both
            
            program1 = concept1.to_z3(*([None] * arity)).program
            program2 = concept2.to_z3(*([None] * arity)).program

            # One list of bound-variable names feeds both the ForAll binder and
            # the predicate call sites, so the two cannot drift apart.
            params = [f"b_{i}" for i in range(arity)]
            quantified_args = "[" + ", ".join(params) + "]"
            args_string = _format_args(params)

            template = Z3Template(
                code=f"""
                params 0;
                bounded params {arity};
                p_0 := Pred({program1.dsl()});
                p_1 := Pred({program2.dsl()});
                ReturnExpr None;
                ReturnPred ForAll({quantified_args}, Implies(p_0({args_string}), p_1({args_string})));
                """
            )
            template.set_args(*args) # Pass any top-level args if needed (usually none for implication)
            return template
        except Exception:
            # Log (with traceback) via logging, consistent with the rest of
            # this module, then re-raise so the caller still sees the failure.
            logging.error(
                "Error translating implication %s -> %s",
                concept1.name, concept2.name, exc_info=True,
            )
            raise

# ==============================
# TEST CONCEPTS AND FUNCTIONS
# ==============================

def create_divisible_by_4_concept():
    """
    Build the ``is_divisible_by_4`` predicate used by the test helpers.

    The concept is unary and numeric. It is seeded with five multiples of 4 as
    examples and five non-multiples as nonexamples.

    Returns:
        Concept: the seeded ``is_divisible_by_4`` predicate.
    """
    from frame.knowledge_base.demonstrations import divides

    structure = ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC,),
        input_arity=1
    )
    div4 = Concept(
        name="is_divisible_by_4",
        description="Tests if a number is divisible by 4",
        symbolic_definition=lambda n: ConceptApplication(divides, Nat(4), n),
        computational_implementation=lambda n: n % 4 == 0,
        example_structure=structure,
    )

    # Examples are single-element tuples because the input arity is 1.
    for value in (4, 8, 12, 16, 20):
        div4.add_example((value,))
    for value in (1, 2, 3, 6, 10):
        div4.add_nonexample((value,))

    return div4

def create_even_square_concept():
    """
    Helper to create a concept for even square numbers.

    First builds an auxiliary ``is_square`` predicate. It then builds the
    ``is_even_square`` predicate, defined symbolically as
    ``is_even(n) ∧ is_square(n)``. Both are seeded with hand-picked examples
    and nonexamples.

    Returns:
        Concept: the ``is_even_square`` predicate (unary, numeric).
    """
    import math

    from frame.knowledge_base.demonstrations import is_even, multiplication

    def is_perfect_square(n):
        # math.isqrt is exact for ints of any size. int(n ** 0.5) rounds
        # through a float, so it can misjudge squares above 2**53 and raises
        # OverflowError for ints too large to convert to float.
        return math.isqrt(n) ** 2 == n

    def unary_numeric_predicate():
        # Each concept gets its own ExampleStructure instance.
        return ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1
        )

    # First create a square concept
    square = Concept(
        name="is_square",
        description="Tests if a number is a perfect square",
        symbolic_definition=lambda n: 
            Exists("k", NatDomain(),
                Equals(n, ConceptApplication(multiplication, Var("k"), Var("k")))
            ),
        computational_implementation=is_perfect_square,
        example_structure=unary_numeric_predicate()
    )
    for value in (0, 1, 4, 9, 16, 25):
        square.add_example((value,))
    for value in (2, 3, 5, 6, 7, 8):
        square.add_nonexample((value,))

    # Now create even square concept
    concept = Concept(
        name="is_even_square",
        description="Tests if a number is both even and a perfect square",
        symbolic_definition=lambda n: And(
            ConceptApplication(is_even, n),
            ConceptApplication(square, n)
        ),
        computational_implementation=lambda n:
            n % 2 == 0 and is_perfect_square(n),
        example_structure=unary_numeric_predicate()
    )
    for value in (0, 4, 16, 36, 64):
        concept.add_example((value,))
    for value in (1, 2, 3, 6, 8, 9, 10, 12, 15):
        concept.add_nonexample((value,))

    return concept

def test_implication():
    """
    Test the Implication production rule on is_even_square → is_divisible_by_4.

    It prints the rule's input types and parameterizations and runs a verbose
    ``can_apply``. If the rule applies, it builds the conjecture and prints
    its symbolic form.

    Returns:
        The produced conjecture, or None if the rule could not be applied.
    """
    print("\n=== Testing Implication Production Rule ===")
    
    # Create concepts
    even_square = create_even_square_concept()
    div_by_4 = create_divisible_by_4_concept()
    
    # Create and apply Implication
    implication = ImplicationRule()
    
    # Test get_input_types
    print("\nTesting get_input_types:")
    input_types = implication.get_input_types()
    print(f"Valid input types: {input_types}")
    
    # Test get_valid_parameterizations
    print("\nTesting get_valid_parameterizations:")
    valid_params = implication.get_valid_parameterizations(even_square, div_by_4)
    print(f"Valid parameterizations: {valid_params}")
    
    print("\nChecking if rule can be applied to even_square → div_by_4...")
    can_apply = implication.can_apply(even_square, div_by_4)
    print(f"Can apply: {can_apply}")
    
    if not can_apply:
        return None

    print("\nApplying rule to create conjecture...")
    conjecture = implication.apply(even_square, div_by_4)

    # Show symbolic form (the heading was previously printed without the expression)
    print("\nSymbolic form of conjecture:")
    print(conjecture.symbolic())

    return conjecture

def test_implication_z3():
    """
    Test the Implication production rule with Z3.

    It builds ``is_even`` and ``is_divisible_by_4`` predicates that carry Z3
    translations. It then asserts that Z3 does not prove the false conjecture
    ``is_even → is_divisible_by_4``, and that it does prove the true conjecture
    ``is_divisible_by_4 → is_even``.
    """
    print("\n=== Testing Implication Production Rule with Z3 ===")
    
    # Create concepts

    is_even = Concept(
        name="is_even",
        description="Tests if a number is even",
        symbolic_definition=lambda n: n % 2 == 0,
        computational_implementation=lambda n: n % 2 == 0,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1
        ),
        z3_translation=lambda n: Z3Template(
            """
            params 1;
            bounded params 0;
            ReturnExpr None;
            ReturnPred x_0 % 2 == 0;
            """,
            n
        )
    )
    # add examples
    is_even.add_example((2,))
    is_even.add_example((4,))
    is_even.add_nonexample((1,))
    is_even.add_nonexample((3,))

    is_divisible_by_4 = Concept(
        name="is_divisible_by_4",
        description="Tests if a number is divisible by 4",
        symbolic_definition=lambda n: n % 4 == 0,
        computational_implementation=lambda n: n % 4 == 0,
        example_structure=ExampleStructure(
            concept_type=ConceptType.PREDICATE,
            component_types=(ExampleType.NUMERIC,),
            input_arity=1
        ),
        z3_translation=lambda n: Z3Template(
            """
            params 1;
            bounded params 0;
            ReturnExpr None;
            ReturnPred x_0 % 4 == 0;
            """,
            n
        )
    )
    # add examples
    is_divisible_by_4.add_example((4,))
    is_divisible_by_4.add_example((8,))
    is_divisible_by_4.add_nonexample((1,))
    # 2 is deliberately NOT recorded as a nonexample here. It is an example of
    # is_even, so ImplicationRule.can_apply would reject is_even →
    # is_divisible_by_4 as contradicted by the data. apply() would then raise
    # ValueError before Z3 ever runs.
    is_divisible_by_4.add_nonexample((3,))
    
    # False conjecture: is_even → is_divisible_by_4 (e.g. 2 is a counterexample)
    incorrect_implication = ImplicationRule()
    incorrect_conjecture = incorrect_implication.apply(is_even, is_divisible_by_4)
    template = incorrect_conjecture.to_z3()
    result = template.run()
    print(f"Z3 Result Proved: {result.proved}")
    assert not result.proved, "Z3 should falsify that is_even implies is_divisible_by_4"
    print("✓ Z3 Test Passed")

    # True conjecture: is_divisible_by_4 → is_even
    correct_implication = ImplicationRule()
    correct_conjecture = correct_implication.apply(is_divisible_by_4, is_even)
    template = correct_conjecture.to_z3()
    result = template.run()
    print(f"Z3 Result Proved: {result.proved}")
    assert result.proved, "Z3 should prove that is_divisible_by_4 implies is_even"
    print("✓ Z3 Test Passed")


if __name__ == "__main__":
    # Run the symbolic implication test first, then the Z3-backed one.
    for _run_test in (test_implication, test_implication_z3):
        _run_test()