"""
This module implements the Match production rule for specializing concepts by making multiple arguments equal.

For example:
- Given multiply(a,b) = a * b
- Applying MatchRule with indices [0,1] produces square(n) = n * n
"""

from typing import List, Optional, Tuple, Any, Dict, Type, Union
from itertools import combinations
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    ConceptApplication,
    Entity,
    Concept,
    ExampleStructure,
    ConceptType,
)

from frame.tools.z3_template import Z3Template, _format_args

class MatchRule(ProductionRule):
    """
    Production rule that takes a concept and produces a new concept by making multiple arguments take the same value.

    For example:
    - Given multiply(a,b) = a * b
    - Produces square(n) = multiply(n,n)

    The rule:
    1. Takes a concept and indices of arguments to match
    2. Creates new symbolic definition where specified arguments are the same
    3. Creates computational implementation that passes the same value to matched arguments
    4. Transforms examples by filtering for cases where specified arguments are equal

    Indices to match must be distinct, non-negative, smaller than the concept's
    input arity, and refer to arguments of the same component type. Their order
    does not matter: `apply` normalizes them to ascending order.
    """

    # Upper bound on how many parameterizations get_valid_parameterizations
    # returns. Smaller index sets (more specific matches) are preferred.
    MAX_PARAMETERIZATIONS = 100

    def __init__(self):
        super().__init__(
            name="match",
            description="Creates a new concept by making multiple arguments take the same value",
            type="Concept",
        )

    def determine_verification_capabilities(self, *inputs: Entity) -> Tuple[bool, bool]:
        """
        Matching arguments preserves the verification capabilities of the input concept.

        Since the match rule simply restricts the domain by requiring certain arguments to be equal,
        the resulting concept inherits the verification capabilities of the original concept.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples); (True, True)
            when no inputs are given.
        """
        if not inputs:
            return True, True

        input_concept = inputs[0]
        return input_concept.can_add_examples, input_concept.can_add_nonexamples

    def get_input_types(
        self,
    ) -> List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]:
        """
        Get the types of inputs this rule expects.

        Returns:
            List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]:
                A list containing a single tuple for the concept (function or predicate)
        """
        return [(Concept, [ConceptType.FUNCTION, ConceptType.PREDICATE])]

    def get_valid_parameterizations(self, *inputs: Entity) -> List[Dict[str, Any]]:
        """
        Get valid parameterizations for matching arguments in the given concept.

        Args:
            *inputs: A single concept (function or predicate) with multiple arguments

        Returns:
            List of valid parameter dictionaries, each containing:
            - indices_to_match: ascending list of (at least two) indices whose
              component types are all equal

        Note:
            At most MAX_PARAMETERIZATIONS (100) entries are returned. Candidates
            are enumerated by increasing size, so the result is the smallest
            index sets (more specific matches).
        """
        if not self.check_input_types(*inputs):
            return []

        concept = inputs[0]

        try:
            arity = concept.get_input_arity()
            if arity <= 1:
                return []  # Concept must take multiple arguments
            component_types = concept.get_component_types()
        except AttributeError:
            return []  # Concept doesn't have the required attributes

        valid_parameterizations: List[Dict[str, Any]] = []
        # Sizes are visited in increasing order, so stopping at the cap gives
        # exactly the "smallest N" selection without enumerating all
        # 2**arity subsets of the argument positions.
        for size in range(2, arity + 1):
            for combo in combinations(range(arity), size):
                base_type = component_types[combo[0]]
                if all(component_types[i] == base_type for i in combo):
                    valid_parameterizations.append({"indices_to_match": list(combo)})
                    if len(valid_parameterizations) >= self.MAX_PARAMETERIZATIONS:
                        return valid_parameterizations

        return valid_parameterizations

    def can_apply(
        self, *inputs: Entity, indices_to_match: List[int], verbose: bool = True
    ) -> bool:
        """
        Check if this rule can be applied to the inputs.

        Requirements:
        1. Single input concept
        2. Input concept must take multiple arguments
        3. At least two indices must be specified to match
        4. Indices must be distinct
        5. Indices must be valid for the concept's arity (0 <= i < arity)
        6. All matched arguments must have the same component type

        Args:
            *inputs: Input entities to check
            indices_to_match: List of indices to match (any order)
            verbose: Whether to print debug information

        Returns:
            True if every requirement holds, False otherwise.
        """
        if verbose:
            print("Checking requirements:")

        if len(inputs) != 1 or not isinstance(inputs[0], Concept):
            if verbose:
                print("❌ Failed: Must have exactly one input of type Concept")
            return False

        concept = inputs[0]
        if verbose:
            print(f"✓ Input is a concept: {concept.name}")

        try:
            arity = concept.get_input_arity()
            if arity <= 1:
                if verbose:
                    print("❌ Failed: Concept must take multiple arguments")
                return False

            if verbose:
                print(f"✓ Concept takes {arity} arguments")

            if len(indices_to_match) < 2:
                if verbose:
                    print("❌ Failed: Must specify at least two indices to match")
                return False

            # Duplicates (e.g. [0, 0]) would yield an identity "specialization".
            if len(set(indices_to_match)) != len(indices_to_match):
                if verbose:
                    print("❌ Failed: Indices to match must be distinct")
                return False

            # Negative indices would be silently ignored when building the
            # reduced signature but still used when filtering examples.
            if min(indices_to_match) < 0:
                if verbose:
                    print("❌ Failed: Indices to match must be non-negative")
                return False

            if max(indices_to_match) >= arity:
                if verbose:
                    print("❌ Failed: Indices to match exceed concept arity")
                return False

            component_types = concept.get_component_types()
            base_type = component_types[indices_to_match[0]]
            if not all(component_types[i] == base_type for i in indices_to_match):
                if verbose:
                    print("❌ Failed: All matched arguments must have the same type")
                return False

            if verbose:
                print(f"✓ Valid indices to match: {indices_to_match}")
            return True

        except AttributeError as e:
            if verbose:
                print(f"❌ Failed: {str(e)}")
            return False

    @staticmethod
    def _build_index_map(input_arity: int, indices_to_match: List[int]) -> Dict[int, int]:
        """
        Map each original input index to its position in the reduced signature.

        All matched indices share the slot assigned when the first of them is
        reached (scanning left to right). Every other index gets its own slot.
        Slots are numbered 0, 1, ... in order of first appearance. The result
        does not depend on the order of indices_to_match.
        """
        matched = set(indices_to_match)
        old_to_new: Dict[int, int] = {}
        shared_slot: Optional[int] = None
        next_slot = 0
        for i in range(input_arity):
            if i in matched:
                if shared_slot is None:
                    shared_slot = next_slot
                    next_slot += 1
                old_to_new[i] = shared_slot
            else:
                old_to_new[i] = next_slot
                next_slot += 1
        return old_to_new

    @staticmethod
    def _collapse_inputs(values: Any, old_to_new: Dict[int, int], input_arity: int) -> List[Any]:
        """
        Keep one entry of values per reduced slot.

        The kept entry is the one at the first original input index mapped to
        that slot. Only indices 0..input_arity-1 are considered, so any trailing
        output entry is dropped.
        """
        collapsed = []
        seen_slots = set()
        for i in range(input_arity):
            slot = old_to_new[i]
            if slot not in seen_slots:
                collapsed.append(values[i])
                seen_slots.add(slot)
        return collapsed

    def apply(self, *inputs: Entity, indices_to_match: List[int]) -> Entity:
        """
        Apply the match rule to create a new concept.

        For n-ary concept C(x1,...,xn), creates new concept where specified indices
        take the same value. For example, C(x,y,z) with indices [0,2] becomes C'(x,y)
        where C'(a,b) = C(a,b,a).

        Args:
            *inputs: A single concept (function or predicate)
            indices_to_match: Distinct input indices to make equal; order is ignored

        Returns:
            The new specialized Concept, with examples and nonexamples carried
            over from the input where the matched positions agree.

        Raises:
            ValueError: if can_apply rejects the inputs.
        """
        if not self.can_apply(*inputs, indices_to_match=indices_to_match):
            raise ValueError("Cannot apply Match to these inputs")

        # Normalize so [2, 0] and [0, 2] build the same concept. This also
        # avoids aliasing the caller's list in the stored attribute below.
        indices_to_match = sorted(indices_to_match)

        concept = inputs[0]
        input_arity = concept.get_input_arity()
        component_types = concept.get_component_types()
        concept_type = concept.examples.example_structure.concept_type
        is_predicate = concept_type == ConceptType.PREDICATE

        old_to_new = self._build_index_map(input_arity, indices_to_match)

        def matched_compute(*args):
            """Call the original concept, repeating the shared argument at matched positions."""
            expanded_args = [args[old_to_new[i]] for i in range(input_arity)]
            return concept.compute(*expanded_args)

        def _z3_translate_match(*args):
            """Build a Z3 template that calls the original concept with matched arguments."""
            # Arguments to the original concept, written in terms of the reduced
            # concept's parameters x_0, x_1, ... (matched positions repeat a name).
            expanded_args = [f"x_{old_to_new[i]}" for i in range(input_arity)]

            concept_template = concept.to_z3(*([None] * concept.get_input_arity()))
            program = concept_template.program

            # NOTE(review): assumes program.params equals the original input
            # arity, so this is the reduced concept's parameter count.
            code = f"""
            params {program.params - len(indices_to_match) + 1}; 
            bounded params 0;
            """

            if is_predicate:
                code += f"""
                p_0 := Pred(
                    {program.dsl()}
                );
                """

                code += f"""
                ReturnExpr None; 
                ReturnPred p_0({_format_args(expanded_args)});
                """
            else:
                code += f"""
                f_0 := Func(
                    {program.dsl()}
                );
                """
                code += f"""
                ReturnExpr f_0({_format_args(expanded_args)});
                ReturnPred None;
                """

            template = Z3Template(code)
            template.set_args(*args)
            return template

        # Reduced input types: one per slot, then the output type for functions.
        new_types = self._collapse_inputs(component_types, old_to_new, input_arity)
        if not is_predicate:
            new_types.append(component_types[-1])
            new_input_arity = len(new_types) - 1
        else:
            new_input_arity = len(new_types)
        new_types = tuple(new_types)

        can_add_examples, can_add_nonexamples = self.determine_verification_capabilities(*inputs)

        new_concept = Concept(
            name=f"matched_({concept.name}_indices_{indices_to_match})",
            description=f"Specialization of {concept.name} with arguments {indices_to_match} made equal",
            symbolic_definition=lambda *args: ConceptApplication(
                concept,
                *[args[old_to_new[i]] for i in range(input_arity)],
            ),
            computational_implementation=matched_compute
                if concept.has_computational_implementation()
                else None,
            example_structure=ExampleStructure(
                concept_type=concept_type,
                component_types=new_types,
                input_arity=new_input_arity,
            ),
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
            z3_translation=_z3_translate_match
                if concept.has_z3_translation()
                else None,
        )

        # Stored for use by _transform_examples.
        new_concept._indices_to_match = indices_to_match
        new_concept._old_to_new = old_to_new
        new_concept.map_iterate_depth = concept.map_iterate_depth
        self._transform_examples(new_concept, concept)

        return new_concept

    def _transform_examples(self, new_concept: Entity, concept: Entity):
        """
        Transform examples from the base concept into examples for the specialized concept.

        Only entries whose values at the matched indices are all equal are
        kept. This applies to both examples and nonexamples: a nonexample with
        equal matched values shows the specialized concept can still be false.
        Entries that are not tuples are skipped. Entries rejected by
        add_example/add_nonexample with ValueError are skipped.

        Reads new_concept._indices_to_match and new_concept._old_to_new, which
        apply() sets before calling this method.
        """
        indices_to_match = new_concept._indices_to_match
        old_to_new = new_concept._old_to_new
        input_arity = concept.get_input_arity()
        has_output = concept.examples.example_structure.concept_type != ConceptType.PREDICATE

        def project(value):
            """Return the reduced tuple for value, or None if it cannot be carried over."""
            if not isinstance(value, tuple):
                return None
            first = value[indices_to_match[0]]
            if not all(value[i] == first for i in indices_to_match):
                return None
            reduced = self._collapse_inputs(value, old_to_new, input_arity)
            if has_output:
                reduced.append(value[-1])
            return tuple(reduced)

        def transfer(entries, add):
            """Project each entry and add the ones that survive via add."""
            for ex in entries:
                reduced = project(ex.value)
                if reduced is None:
                    continue
                try:
                    add(reduced)
                except ValueError:
                    continue

        transfer(concept.examples.get_examples(), new_concept.add_example)
        transfer(concept.examples.get_nonexamples(), new_concept.add_nonexample)

    