"""
This module implements the Exclusivity (Applicability) production rule, which generates conjectures
stating that a concept holds exactly on a particular finite set of inputs.

For example:
- "2 is the only even prime number"
- Fermat's Last Theorem (a^n + b^n = c^n has solutions in positive integers a, b, and c
only when n = 1 or n = 2)
"""

from typing import List, Optional, Set, Tuple, Any, Dict, Type
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    Expression, Var, ConceptApplication, Forall, NatDomain,
    Entity, Concept, Conjecture, Example, ExampleType, ExampleStructure, ConceptType,
    Equals, Implies, Not, And, Exists, Nat, Fold, Lambda,
    In, Set as SetExpr, TupleDomain, TupleExpr,
)
from frame.knowledge_base.demonstrations import (
    is_prime, is_even
)

class ExclusivityRule(ProductionRule):
    """
    Production rule that takes a concept and produces a conjecture stating that the concept is
    satisfied exactly on a particular finite set of inputs.

    For predicates P: A₁ × ... × Aₙ → Bool, produces:
    ∀x₁...xₙ. P(x₁,...,xₙ) ↔ (x₁,...,xₙ) ∈ S

    For functions f: A₁ × ... × Aₙ → B and value v, produces:
    ∀x₁...xₙ. f(x₁,...,xₙ) = v ↔ (x₁,...,xₙ) ∈ S

    The biconditional is encoded as a conjunction of two implications, and every
    variable is universally quantified over NatDomain.

    Example:
    - Given is_even_prime(n) = is_even(n) ∧ is_prime(n)
    - Produces conjecture that ∀n. (is_even(n) ∧ is_prime(n)) ↔ n ∈ S = {2}
    """

    def __init__(self):
        super().__init__(
            name="exclusivity",
            description="Creates a conjecture stating that a particular finite set of inputs satisfies a concept",
            type="Conjecture"
        )

    def determine_verification_capabilities(self, *inputs: Entity, valid_set: Optional[Set] = None, target_value: Optional[Any] = None) -> Tuple[bool, bool]:
        """
        Determine whether the generated conjecture can have examples / nonexamples added.

        Membership in ``valid_set`` is always decidable (it is an explicit finite set),
        so the capabilities are inherited directly from the input concept's
        ``can_add_examples`` / ``can_add_nonexamples`` flags.

        Args:
            *inputs: The rule inputs; the first must be a ``Concept``.
            valid_set: The finite set of input tuples; ``None`` disables both capabilities.
            target_value: Accepted for signature symmetry with ``apply``; not used here.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples). ``(False, False)``
            when there are no inputs, the first input is not a ``Concept``, or
            ``valid_set`` is ``None``.
        """
        if not inputs or not isinstance(inputs[0], Concept) or valid_set is None:
            return False, False

        concept = inputs[0]
        return concept.can_add_examples, concept.can_add_nonexamples

    def get_input_types(self) -> List[List[Tuple[Type, Any]]]:
        """
        Returns the valid input types for this production rule.

        Returns:
            A list of lists, where each inner list represents a valid combination of input types.
            Each element in the inner list is a tuple of (type, subtype).
        """
        # NOTE(review): the function variant lists a second (target value) concept input,
        # but can_apply/apply take the target via the ``target_value`` keyword and require
        # exactly one positional input.
        return [
            [(Concept, ConceptType.PREDICATE), (SetExpr, None)],  # Predicate concept and valid set
            [(Concept, ConceptType.FUNCTION), (Concept, ConceptType.FUNCTION), (SetExpr, None)]  # Function, target value concept, and valid set
        ]

    def get_valid_parameterizations(
        self,
        *inputs: Entity,
        valid_set: Optional[Set[Tuple]] = None,
        target_value: Optional[Any] = None,
    ) -> List[Dict[str, Any]]:
        """
        Get valid parameterizations for applying the Exclusivity rule.

        The valid set cannot be derived from the inputs alone, so it must be supplied
        by the caller; without it the rule is not applicable and ``[]`` is returned.

        Args:
            *inputs: The single concept the conjecture is about.
            valid_set: Candidate finite set of input tuples.
            target_value: Target output value (function concepts only).

        Returns:
            ``[params]`` where ``params`` contains ``valid_set`` (and ``target_value`` when
            given), ready to be passed to :meth:`apply` as keyword arguments, or ``[]``
            when the rule cannot be applied.
        """
        if not self.can_apply(*inputs, valid_set=valid_set, target_value=target_value, verbose=False):
            return []  # Invalid inputs

        params: Dict[str, Any] = {"valid_set": valid_set}
        if target_value is not None:
            params["target_value"] = target_value
        return [params]

    def can_apply(self, *inputs: Entity, valid_set: Optional[Set[Tuple]] = None,
                  target_value: Optional[Any] = None, verbose: bool = False) -> bool:
        """
        Check if this rule can be applied to the inputs.

        Requirements:
        1. Exactly one input, which is a ``Concept``, and ``valid_set`` is a set/frozenset
        2. ``valid_set`` is non-empty and every element is a tuple of the concept's input arity
        3. For predicates: at least one positive example
        4. For functions: ``target_value`` is given and at least one example maps to it
        5. The concept has at least one nonexample

        Args:
            *inputs: The concept to test.
            valid_set: Candidate finite set of input tuples. ``None`` (the default) means
                the rule cannot be applied.
            target_value: Required for function concepts; ignored for predicates.
            verbose: When True, print which requirement passed or failed.

        Returns:
            True if every requirement holds, otherwise False.
        """
        def log(message: str) -> None:
            # Diagnostics are opt-in so library callers (e.g. ``apply``) stay quiet.
            if verbose:
                print(message)

        log("Checking requirements:")

        if len(inputs) != 1 or not isinstance(inputs[0], Concept) or not isinstance(valid_set, (set, frozenset)):
            log("❌ Failed: Must have exactly one input of type Concept and one input of type Set")
            return False

        concept = inputs[0]
        log(f"✓ Input is a concept: {concept.name} \n and a set: {valid_set}")

        structure = concept.examples.example_structure

        # Check valid_set is non-empty and contains tuples of correct arity
        arity = structure.input_arity
        if not valid_set or not all(isinstance(x, tuple) and len(x) == arity for x in valid_set):
            log(f"❌ Failed: Valid set must contain tuples of arity {arity}")
            return False

        examples = concept.examples.get_examples()

        if structure.concept_type == ConceptType.PREDICATE:
            if not examples:
                log("❌ Failed: Predicate has no positive examples")
                return False
        else:  # Function case
            if target_value is None:
                log("❌ Failed: Must specify target value for function")
                return False

            # The last component of a function example's value is treated as its output.
            if not any(ex.value[-1] == target_value for ex in examples):
                log("❌ Failed: Function has no examples mapping to target value")
                return False

        # Nonexamples are evidence the concept is not satisfied everywhere.
        if not concept.examples.get_nonexamples():
            log("❌ Failed: Must have nonexamples to support exclusivity")
            return False

        log("✓ Valid concept for exclusivity conjecture")
        return True

    def apply(self, *inputs: Entity, valid_set: Optional[Set] = None, target_value: Optional[Any] = None, **params) -> Entity:
        """
        Apply the exclusivity rule to generate a conjecture.

        Args:
            *inputs: The single concept the conjecture is about (predicate or function).
            valid_set: Set of tuples on which the concept holds (can also be provided in params)
            target_value: Target value for function concepts (can also be provided in params)
            **params: Additional parameters including valid_set and target_value if not provided as kwargs

        Returns:
            A conjecture stating that the concept is satisfied exactly on the valid set

        Raises:
            ValueError: If :meth:`can_apply` rejects the inputs/parameters.
        """
        # Explicit keyword arguments take precedence over entries in **params.
        valid_set = valid_set if valid_set is not None else params.get('valid_set')
        target_value = target_value if target_value is not None else params.get('target_value')

        if not self.can_apply(*inputs, valid_set=valid_set, target_value=target_value):
            raise ValueError("Cannot apply Exclusivity rule with given inputs and parameters")

        concept = inputs[0]
        structure = concept.examples.example_structure
        arity = structure.input_arity
        is_predicate = structure.concept_type == ConceptType.PREDICATE

        can_add_examples, can_add_nonexamples = self.determine_verification_capabilities(
            *inputs, valid_set=valid_set, target_value=target_value
        )

        # Bound variables x0..x{n-1}; the names must match the Forall binders below.
        variables = [Var(f"x{i}") for i in range(arity)]

        # Encode the Python set of tuples as a SetExpr over NatDomain^arity.
        tuple_domain = TupleDomain([NatDomain() for _ in range(arity)])
        valid_tuples = [TupleExpr([Nat(x) for x in t]) for t in valid_set]
        valid_set_expr = SetExpr(tuple_domain, elements=valid_tuples)

        def membership() -> Expression:
            # (x0, ..., x{n-1}) ∈ S; a fresh node per call so the implications share no nodes.
            return In(TupleExpr(variables), valid_set_expr)

        def condition() -> Expression:
            # P(x0, ..., x{n-1}) for predicates, f(x0, ..., x{n-1}) = v for functions.
            application = ConceptApplication(concept, *variables)
            return application if is_predicate else Equals(application, Nat(target_value))

        name_suffix = "" if is_predicate else f"_equals_{target_value}"

        # Biconditional encoded as (condition → in_set) ∧ (in_set → condition).
        exclusivity_expr = And(
            Implies(condition(), membership()),
            Implies(membership(), condition()),
        )

        # Note(_; 2/18): If we expand quantification to support many variables without nesting, modify here.
        # Wrap with universal quantifiers, innermost binder last.
        expr = exclusivity_expr
        for i in reversed(range(arity)):
            expr = Forall(f"x{i}", NatDomain(), expr)

        # Create the conjecture with the same example structure as the input concept
        conjecture = Conjecture(
            name=f"excl_({concept.name}{name_suffix}_{valid_set})",
            description=f"Conjecture that '{concept.name}' is satisfied exactly on the set {valid_set}",
            symbolic_definition=lambda: expr,
            example_structure=structure,
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples
        )
        conjecture.map_iterate_depth = concept.map_iterate_depth
        return conjecture
    
def create_even_prime_concept():
    """
    Build a hand-made ``is_even_prime`` predicate for the exclusivity demo.

    The predicate is the conjunction of ``is_even`` and ``is_prime`` on one numeric
    argument. It is seeded with the single example (2,) and the nonexamples
    (1,), (3,), (6,) and (28,), in that order.

    Returns:
        The populated ``Concept``.
    """
    unary_numeric_predicate = ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC,),
        input_arity=1
    )

    even_prime = Concept(
        name="is_even_prime",
        description="Tests if a number is both even and prime",
        symbolic_definition=lambda n: And(
            ConceptApplication(is_even, n),
            ConceptApplication(is_prime, n)
        ),
        computational_implementation=lambda n:
            is_even.compute(n) and is_prime.compute(n),
        example_structure=unary_numeric_predicate,
    )

    even_prime.add_example((2,))

    # 1 is neither even nor prime, 3 is odd, 6 and 28 are even but composite.
    for counterexample in ((1,), (3,), (6,), (28,)):
        even_prime.add_nonexample(counterexample)

    return even_prime

def create_sum_equals_concept():
    """
    Build a hand-made binary predicate ``sum_equals_4`` for the exclusivity demo.

    The predicate holds for natural-number pairs (x, y) with x + y == 4. Every such
    pair, (0, 4) through (4, 0), is added as an example, and four pairs that do not
    sum to 4 are added as nonexamples.

    Returns:
        The populated ``Concept``.
    """
    from frame.knowledge_base.demonstrations import addition

    # All pairs summing to 4, in increasing order of the first component.
    positive_pairs = [(left, 4 - left) for left in range(5)]
    # Sums 0, 2, 5, 5 respectively.
    negative_pairs = [(0, 0), (1, 1), (2, 3), (5, 0)]

    binary_numeric_predicate = ExampleStructure(
        concept_type=ConceptType.PREDICATE,
        component_types=(ExampleType.NUMERIC, ExampleType.NUMERIC),
        input_arity=2,
    )

    sum_concept = Concept(
        name="sum_equals_4",
        description="Tests if two numbers sum to 4",
        symbolic_definition=lambda x, y: Equals(
            ConceptApplication(addition, x, y), Nat(4)
        ),
        computational_implementation=lambda x, y: x + y == 4,
        example_structure=binary_numeric_predicate,
    )

    for pair in positive_pairs:
        sum_concept.add_example(pair)
    for pair in negative_pairs:
        sum_concept.add_nonexample(pair)

    return sum_concept

def test_exclusivity():
    """
    Smoke-test the Exclusivity production rule on two hand-built concepts.

    Builds ``is_even_prime`` (valid set {(2,)}) and ``sum_equals_4`` (valid set of the
    five pairs summing to 4). For each one it checks applicability verbosely, applies
    the rule, and prints the resulting conjecture's symbolic form.

    Returns:
        ``(excl_even_prime, excl_sum_equals_4)`` when both conjectures were produced,
        otherwise ``None``.
    """
    print("\n=== Testing Exclusivity Production Rule ===")

    # Bind results up front so a failed applicability check cannot leave them unbound.
    excl_even_prime = None
    excl_sum_equals_4 = None

    # Test 1: The only even prime number is 2
    print("\n--- Testing Even Prime Numbers ---")
    even_prime = create_even_prime_concept()
    print("\nCreated even_prime concept with examples:")
    print(f"Examples: {[ex.value for ex in even_prime.examples.get_examples()]}")
    print(f"Nonexamples: {[ex.value for ex in even_prime.examples.get_nonexamples()]}")

    # Create and apply Exclusivity
    exclusivity = ExclusivityRule()
    print("\nChecking if rule can be applied to even_prime...")
    can_apply = exclusivity.can_apply(even_prime, valid_set={(2,)}, verbose=True)
    print(f"Can apply: {can_apply}")

    if can_apply:
        print("\nApplying rule to create excl_even_prime_numbers conjecture...")
        excl_even_prime = exclusivity.apply(even_prime, valid_set={(2,)})

        print("\nSymbolic form of conjecture:")
        print(excl_even_prime.symbolic())

    # Test 2: Pairs of numbers that sum to 4
    print("\n--- Testing Sum Equals 4 (Binary Predicate) ---")
    sum_equals_4 = create_sum_equals_concept()
    print("\nCreated sum_equals_4 concept with examples:")
    print(f"Examples: {[ex.value for ex in sum_equals_4.examples.get_examples()]}")
    print(
        f"Nonexamples: {[ex.value for ex in sum_equals_4.examples.get_nonexamples()]}"
    )

    valid_pairs = {(0, 4), (1, 3), (2, 2), (3, 1), (4, 0)}
    print("\nChecking if rule can be applied to sum_equals_4...")
    can_apply = exclusivity.can_apply(sum_equals_4, valid_set=valid_pairs, verbose=True)
    print(f"Can apply: {can_apply}")

    if can_apply:
        print("\nApplying rule to create excl_sum_equals_4 conjecture...")
        excl_sum_equals_4 = exclusivity.apply(sum_equals_4, valid_set=valid_pairs)

        print("\nSymbolic form of conjecture:")
        print(excl_sum_equals_4.symbolic())

    if excl_even_prime is not None and excl_sum_equals_4 is not None:
        return excl_even_prime, excl_sum_equals_4
    return None

if __name__ == "__main__":
    # Run the demonstration only when executed as a script, never on import.
    _ = test_exclusivity()