"""
This module implements the Specialize production rule for specializing predicates by fixing arguments.
"""

from typing import List, Optional, Tuple, Any, Dict, Type, Union
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    ConceptApplication,
    Entity,
    Concept,
    ExampleStructure,
    ConceptType,
    Equals,
    Nat,
)

from frame.tools.z3_template import Z3Template, _format_args

class SpecializeRule(ProductionRule):
    """
    Production rule that takes a predicate or function concept and produces a new concept by fixing one of its arguments to a specific value.

    For example:
    - Given proper_divisors_count(n, k) which checks if k is the number of proper divisors of n
    - Applying SpecializeRule with index=1 and value=one_concept
    - Produces specialized_(proper_divisors_count_at_1_to_one)(n) which checks if n has exactly one proper divisor

    Or for functions:
    - Given tau(n) which returns the number of divisors of n
    - Applying SpecializeRule with index=1 (the output position, since tau has one input) and value=two_concept
    - Produces specialized_(tau_output_eq_two)(n), a predicate which checks whether n has exactly 2 divisors (is prime)

    Special cases:
    1. Function outputs:
    - When index_to_specialize equals the function's input_arity, we're specializing the output
    - For example, specializing tau(n) with index=1 and two_concept creates a predicate is_prime(n) that checks if tau(n) == 2
    - This produces an Equals(function, value) concept

    2. Value concepts:
    - The value concept can be either:
      a) A constant concept (preferred)
      b) A zero-arity function whose output type matches the specialized argument
         (predicates are not accepted as value concepts)
    """

    def __init__(self, verbose: bool = False):
        """Register the rule's metadata with the base class and keep the verbosity flag.

        Args:
            verbose: Stored unchanged on the instance as ``self.verbose``.
        """
        rule_description = (
            "Creates a new concept by fixing one argument of a predicate or function "
            "to a specific value, or by specializing a function's output"
        )
        super().__init__(
            name="specialize",
            description=rule_description,
            type="Concept",
        )
        self.verbose = verbose

    def get_input_types(
        self,
    ) -> List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]:
        """
        Describe the two inputs this rule consumes.

        Returns:
            A two-element list of ``(entity class, accepted concept types)`` pairs:
                1. The concept to specialize: a function or a predicate.
                2. The value concept: a constant (preferred) or a function.
        """
        # First input: the function or predicate being specialized.
        target_spec = (Concept, [ConceptType.FUNCTION, ConceptType.PREDICATE])
        # Second input: the concept supplying the fixed value.
        value_spec = (Concept, [ConceptType.CONSTANT, ConceptType.FUNCTION])
        return [target_spec, value_spec]

    def get_valid_parameterizations(self, *inputs: Entity) -> List[Dict[str, Any]]:
        """
        Get valid parameterizations for specializing the given concepts.

        The structural checks mirror those in ``can_apply`` so that this method
        does not propose parameterizations that ``apply`` would reject:
        the first concept must be a function or predicate with at least one
        input, the value concept must be a zero-arity constant or function,
        and a predicate needs at least two inputs before one of them can be
        fixed. ``can_apply`` additionally requires the value concept to have a
        computational implementation or examples; that data-dependent check is
        left to ``can_apply``.

        Args:
            *inputs: Two concepts - the concept to specialize and the value concept

        Returns:
            List[Dict[str, Any]]: List of valid parameter dictionaries, each containing:
            - index_to_specialize: The index of the argument to specialize. An index
              equal to the concept's input arity denotes output specialization
              (functions only).
        """
        if len(inputs) != 2 or not all(isinstance(x, Concept) for x in inputs):
            return []

        concept, value_concept = inputs
        concept_type = concept.examples.example_structure.concept_type
        input_arity = concept.get_input_arity()

        if concept_type not in (ConceptType.FUNCTION, ConceptType.PREDICATE):
            return []
        if input_arity == 0:
            return []

        # Only zero-arity constants and functions can supply a fixed value.
        value_type = value_concept.examples.example_structure.concept_type
        if value_type not in (ConceptType.CONSTANT, ConceptType.FUNCTION):
            return []
        if value_concept.get_input_arity() != 0:
            return []

        value_component_types = value_concept.get_component_types()
        if value_type == ConceptType.CONSTANT:
            value_output_type = value_component_types[0]
        else:
            # For functions, use their output type
            value_output_type = value_component_types[-1]

        component_types = concept.get_component_types()
        valid_parameterizations: List[Dict[str, Any]] = []

        # Input specialization: a predicate must keep at least one argument,
        # so it needs two or more inputs; a function needs only one.
        min_arity = 2 if concept_type == ConceptType.PREDICATE else 1
        if input_arity >= min_arity:
            valid_parameterizations.extend(
                {"index_to_specialize": index}
                for index in range(input_arity)
                if component_types[index] == value_output_type
            )

        # Output specialization (only for functions with a single output).
        if concept_type == ConceptType.FUNCTION:
            output_count = len(component_types) - input_arity
            if component_types[-1] == value_output_type and output_count <= 1:
                valid_parameterizations.append({"index_to_specialize": input_arity})

        return valid_parameterizations

    def can_apply(
        self, *inputs: Entity, index_to_specialize: int, verbose: bool = True
    ) -> bool:
        """
        Check whether this rule can be applied to the inputs.

        Requirements:
        1. Exactly two Concept inputs:
           - The first must be a function or a predicate (a function needs at
             least one input).
           - The second must be a constant or a function, have input arity 0,
             and have either a computational implementation or at least one
             example.
        2. Input specialization: 0 <= index_to_specialize < input_arity, and the
           concept must take at least 2 arguments if it is a predicate
           (1 if it is a function).
        3. Output specialization: index_to_specialize == input_arity, the concept
           must be a function, and it must not have more than one output component.
        4. The component type at the specialization point must equal the value
           concept's first component type.

        Args:
            *inputs: The concept to specialize followed by the value concept.
            index_to_specialize: Index of the argument (or output) to fix.
            verbose: When True, print a trace of the checks and the failure reason.

        Returns:
            bool: True if the rule can be applied, False otherwise.
        """

        def reject(reason: str) -> bool:
            # Single exit point for failures so every message has the same shape.
            if verbose:
                print(f"❌ Failed: {reason}")
            return False

        def accept(message: str) -> bool:
            if verbose:
                print(f"✓ {message}")
            return True

        if verbose:
            print("\nChecking if SpecializeRule can be applied:")

        if len(inputs) != 2 or not all(isinstance(x, Concept) for x in inputs):
            return reject("Must have exactly two inputs of type Concept")

        concept, value_concept = inputs
        if verbose:
            print(f"✓ First input is a concept: {concept.name}")
            print(f"✓ Second input is a concept: {value_concept.name}")

        # Get concept types
        concept_type = concept.examples.example_structure.concept_type
        value_type = value_concept.examples.example_structure.concept_type

        # Check that first concept is function or predicate
        if concept_type not in [ConceptType.FUNCTION, ConceptType.PREDICATE]:
            return reject("First concept must be a function or predicate")

        if concept_type == ConceptType.FUNCTION and concept.get_input_arity() == 0:
            return reject("Function must have at least one input")

        # Check that value concept is a valid type (predicates are not accepted)
        if value_type not in [ConceptType.CONSTANT, ConceptType.FUNCTION]:
            return reject("Value concept must be a constant or function")

        if value_concept.get_input_arity() != 0:
            return reject("Value concept must have arity 0")

        # The fixed value is later read from compute() or from the first example,
        # so at least one of the two sources must exist.
        if not (
            value_concept.has_computational_implementation()
            or len(value_concept.examples.get_examples()) > 0
        ):
            return reject(
                "Value concept must have a computational implementation or examples"
            )

        input_arity = concept.examples.example_structure.input_arity

        # Output specialization: the index points one past the last input.
        if index_to_specialize == input_arity:
            if concept_type != ConceptType.FUNCTION:
                return reject("For output specialization, concept must be a function")

            # Check if the function has more than one output variable
            output_count = len(concept.get_component_types()) - input_arity
            if output_count > 1:
                return reject(
                    "Cannot specialize output of a function with multiple outputs"
                )

            # Check type compatibility between function output and value
            function_output_type = concept.get_component_types()[-1]
            value_output_type = value_concept.get_component_types()[0]
            if function_output_type != value_output_type:
                return reject("Type mismatch between function output and value")

            return accept("Valid function output to specialize")

        # Input specialization: check minimum arity
        min_arity = 2 if concept_type == ConceptType.PREDICATE else 1
        if input_arity < min_arity:
            return reject(f"Concept must take at least {min_arity} arguments")

        # Validate specialization index
        if not (0 <= index_to_specialize < input_arity):
            return reject("Invalid specialization index")

        # Check type compatibility
        concept_type_at_index = concept.get_component_types()[index_to_specialize]
        value_output_type = value_concept.get_component_types()[0]
        if concept_type_at_index != value_output_type:
            return reject("Type mismatch between concept argument and value")

        return accept("Valid concept to specialize")

    def determine_verification_capabilities(
        self, *inputs: Entity, index_to_specialize: int
    ) -> Tuple[bool, bool]:
        """
        Work out whether the specialized concept may receive examples and nonexamples.

        - With fewer than two inputs, both capabilities are reported as True.
        - Output specialization (``index_to_specialize`` equals the base concept's
          input arity): each capability of the base concept is combined with the
          value concept's ``can_add_examples`` flag.
        - Input specialization: the base concept's capabilities are returned as-is.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples)
        """
        if len(inputs) < 2:
            return True, True

        base, value = inputs

        # Read every flag up front, before deciding which branch applies.
        base_examples_ok = base.can_add_examples
        base_nonexamples_ok = base.can_add_nonexamples
        value_examples_ok = value.can_add_examples

        arity = base.examples.example_structure.input_arity
        if index_to_specialize != arity:
            # Input specialization: inherit from the base concept.
            return base_examples_ok, base_nonexamples_ok

        # Output specialization: both the function and the value must be reliable.
        return (
            base_examples_ok and value_examples_ok,
            base_nonexamples_ok and value_examples_ok,
        )

    def apply(
        self, *inputs: Entity, index_to_specialize: int, verbose: bool = False
    ) -> Entity:
        """
        Build the specialized concept for the given inputs.

        Args:
            *inputs: The concept to specialize and the value concept
            index_to_specialize: Index of the argument to fix; the concept's input
                arity selects output specialization
            verbose: Forwarded to ``can_apply`` to control its diagnostic output

        Returns:
            Entity: The specialized concept, with ``map_iterate_depth`` copied from
            the base concept and examples derived from the base concept's examples

        Raises:
            ValueError: If ``can_apply`` rejects the inputs
        """
        applicable = self.can_apply(
            *inputs, index_to_specialize=index_to_specialize, verbose=verbose
        )
        if not applicable:
            raise ValueError("Cannot apply Specialization to these inputs")

        concept, value_concept = inputs
        value_name = value_concept.name
        fixed_value = self._get_fixed_value(value_concept)

        # Verification capabilities of the concept about to be created.
        capabilities = self.determine_verification_capabilities(
            *inputs, index_to_specialize=index_to_specialize
        )
        can_add_examples, can_add_nonexamples = capabilities

        specializing_output = index_to_specialize == concept.get_input_arity()
        if specializing_output:
            new_concept = self._create_output_specialized_concept(
                concept,
                fixed_value,
                value_name,
                value_concept,
                can_add_examples,
                can_add_nonexamples,
            )
        else:
            new_concept = self._create_input_specialized_concept(
                concept,
                index_to_specialize,
                fixed_value,
                value_name,
                value_concept,
                can_add_examples,
                can_add_nonexamples,
            )

        # Both kinds of specialization keep the base concept's iterate depth.
        new_concept.map_iterate_depth = concept.map_iterate_depth

        # Carry over whatever examples survive the specialization.
        self._transform_examples(new_concept, concept, index_to_specialize, fixed_value)
        return new_concept

    def _get_fixed_value(self, value_concept: Entity) -> Any:
        """Extract the fixed value from a value concept.

        The value is taken from ``value_concept.compute()`` when that call
        succeeds. If it raises ``AttributeError`` or ``TypeError``, the first
        stored example is used instead: its first component for a constant, its
        last component otherwise. An example value that is not a sequence is
        returned as-is.

        Args:
            value_concept: The concept supplying the value to fix.

        Returns:
            Any: The value the specialized argument (or output) is fixed to.

        Raises:
            ValueError: If neither ``compute()`` nor the examples yield a value.
                The underlying exception is attached as ``__cause__``.
        """
        try:
            is_constant = (
                value_concept.examples.example_structure.concept_type
                == ConceptType.CONSTANT
            )
            try:
                return value_concept.compute()
            except (AttributeError, TypeError):
                # No usable implementation: fall back to the first stored example.
                example = next(iter(value_concept.examples.get_examples()))

            raw = example.value
            if is_constant:
                return raw[0] if isinstance(raw, tuple) else raw
            # Functions store (inputs..., output); a bare scalar is the output itself.
            return raw[-1] if isinstance(raw, (tuple, list)) else raw
        except (StopIteration, AttributeError, TypeError, IndexError) as e:
            raise ValueError(
                f"Could not determine fixed value from value concept {value_concept.name}. "
                "Make sure it either has a compute() method or examples."
            ) from e

    def _create_output_specialized_concept(
        self,
        concept: Entity,
        fixed_value: Any,
        value_name: str,
        value_concept: Entity,
        can_add_examples: bool,
        can_add_nonexamples: bool,
    ) -> Entity:
        """Create a predicate that checks whether the function's output equals the fixed value.

        The new concept takes the same inputs as ``concept`` (its component
        types are the base component types without the last one) and is of
        type PREDICATE.

        A computational implementation is attached only when the base concept
        has one, matching ``_create_input_specialized_concept``; otherwise
        ``None`` is passed so the new concept does not get a callable that
        returns ``None`` for every input. The same applies to the Z3 translation.

        Args:
            concept: The function concept whose output is specialized.
            fixed_value: The value the output is compared against.
            value_name: Name of the value concept, used in the new concept's name.
            value_concept: The value concept (currently unused here; kept for a
                signature parallel to ``_create_input_specialized_concept``).
            can_add_examples: Forwarded to the new concept.
            can_add_nonexamples: Forwarded to the new concept.

        Returns:
            Entity: The new predicate concept.
        """

        def output_equals_fixed_value(*args):
            return concept.compute(*args) == fixed_value

        def symbolic_definition(*args):
            # NOTE(review): wrapping in Nat assumes the fixed value is a natural
            # number - confirm against callers with other output types.
            return Equals(ConceptApplication(concept, *args), Nat(fixed_value))

        def z3_translation(*args):
            return self._z3_translate_output_specialized(concept, fixed_value, *args)

        return Concept(
            name=f"specialized_({concept.name}_output_eq_{value_name})",
            description=f"Specialization of {concept.name} checking if output equals {fixed_value}",
            symbolic_definition=symbolic_definition,
            computational_implementation=(
                output_equals_fixed_value
                if concept.has_computational_implementation()
                else None
            ),
            example_structure=ExampleStructure(
                concept_type=ConceptType.PREDICATE,
                component_types=concept.get_component_types()[:-1],
                input_arity=concept.get_input_arity(),
            ),
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
            z3_translation=z3_translation if concept.has_z3_translation() else None,
        )
    
    def _z3_translate_output_specialized(self, concept: Concept, fixed_value: Any, *args) -> Z3Template:
        """Build a Z3 template asserting that the base function's output equals the fixed value.

        The base concept's Z3 program is embedded as ``f_0`` and applied to
        placeholders ``x_0 .. x_{k-1}``, one per entry in ``args``; the template's
        predicate is ``f_0(...) == fixed_value``.

        Args:
            concept: The function concept being specialized.
            fixed_value: Value interpolated on the right-hand side of the equality.
            *args: Arguments bound to the resulting template via ``set_args``.

        Returns:
            Z3Template: The template with ``args`` applied.
        """
        base_program = concept.to_z3(*([None] * concept.get_input_arity())).program
        param_count = base_program.params
        body_dsl = base_program.dsl()
        call_args = _format_args([f"x_{pos}" for pos in range(len(args))])

        # Escaped newlines keep the template text identical to the multi-line layout.
        code = (
            f"\n            params {param_count};"
            "\n            bounded params 0;"
            "\n            f_0 := Func("
            f"\n            {body_dsl}"
            "\n            );"
            "\n            ReturnExpr None;"
            f"\n            ReturnPred f_0({call_args}) == {fixed_value};"
            "\n        "
        )
        template = Z3Template(code=code)
        template.set_args(*args)
        return template

    def _create_input_specialized_concept(
        self,
        concept: Entity,
        index_to_specialize: int,
        fixed_value: Any,
        value_name: str,
        value_concept: Entity,
        can_add_examples: bool,
        can_add_nonexamples: bool,
    ) -> Entity:
        """Build a concept equal to ``concept`` with one input pinned to ``fixed_value``.

        The new concept keeps the base concept's type and drops the component
        type at ``index_to_specialize``. Its computational implementation and Z3
        translation are attached only when the base concept provides them.
        """
        kind = concept.examples.example_structure.concept_type

        # Component types of the new concept: the original ones minus the fixed slot.
        remaining_types = list(concept.get_component_types())
        remaining_types.pop(index_to_specialize)
        if kind == ConceptType.FUNCTION:
            # The last remaining component is the function's output.
            new_arity = len(remaining_types) - 1
        else:
            new_arity = len(remaining_types)
        component_types = tuple(remaining_types)

        def specialized_input_symbolic(*args):
            """Symbolic application of the base concept with the fixed value spliced in."""
            # The fixed value is wrapped in Nat for the symbolic representation.
            spliced = (
                args[:index_to_specialize]
                + (Nat(fixed_value),)
                + args[index_to_specialize:]
            )
            return ConceptApplication(concept, *spliced)

        def specialized_input_compute(*args):
            """Evaluate the base concept with the fixed value spliced in."""
            spliced = (
                args[:index_to_specialize]
                + (fixed_value,)
                + args[index_to_specialize:]
            )
            return concept.compute(*spliced)

        if concept.has_computational_implementation():
            compute_impl = specialized_input_compute
        else:
            compute_impl = None

        # Predicates and functions use different Z3 template builders.
        if not concept.has_z3_translation():
            z3_impl = None
        elif kind == ConceptType.PREDICATE:
            z3_impl = lambda *args: self._z3_translate_input_specialized_predicate(
                concept, index_to_specialize, fixed_value, *args
            )
        else:
            z3_impl = lambda *args: self._z3_translate_input_specialized_function(
                concept, index_to_specialize, fixed_value, *args
            )

        structure = ExampleStructure(
            concept_type=kind,
            component_types=component_types,
            input_arity=new_arity,
        )
        return Concept(
            name=f"specialized_({concept.name}_at_{index_to_specialize}_to_{value_name})",
            description=f"Specialization of {concept.name} with argument {index_to_specialize} fixed to {fixed_value}",
            symbolic_definition=specialized_input_symbolic,
            computational_implementation=compute_impl,
            example_structure=structure,
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
            z3_translation=z3_impl,
        )
    
    def _z3_translate_input_specialized_predicate(
        self,
        concept: Concept,
        index_to_specialize: int,
        fixed_value: Any,
        *args,
    ) -> Z3Template:
        """Build a Z3 template for a predicate with one input fixed.

        The base predicate's Z3 program is embedded as ``p_0`` and applied to
        placeholders ``x_0 .. x_{k-1}`` (one per entry in ``args``) with
        ``fixed_value`` inserted at ``index_to_specialize``. The declared
        parameter count is the base program's count minus one.

        Returns:
            Z3Template: The template with ``args`` applied via ``set_args``.
        """
        program = concept.to_z3(*([None] * concept.get_input_arity())).program
        call_args = [f"x_{pos}" for pos in range(len(args))]
        call_args.insert(index_to_specialize, fixed_value)

        # Three sections: header, embedded predicate, return clauses.
        # Escaped newlines keep the template text identical to the multi-line layout.
        sections = [
            f"\n        params {program.params - 1}; \n        bounded params 0;\n        ",
            f"\n        p_0 := Pred(\n            {program.dsl()}\n        );\n        ",
            f"\n        ReturnExpr None; \n        ReturnPred p_0({_format_args(call_args)});\n        ",
        ]
        template = Z3Template("".join(sections))
        template.set_args(*args)
        return template
    
    def _z3_translate_input_specialized_function(
        self,
        concept: Concept,
        index_to_specialize: int,
        fixed_value: Any,
        *args,
    ) -> Z3Template:
        """Build a Z3 template for a function with one input fixed.

        The base function's Z3 program is embedded as ``f_0`` and applied to
        placeholders ``x_0 .. x_{k-1}`` (one per entry in ``args``) with
        ``fixed_value`` inserted at ``index_to_specialize``. The declared
        parameter count is the base program's count minus one.

        Returns:
            Z3Template: The template with ``args`` applied via ``set_args``.
        """
        program = concept.to_z3(*([None] * concept.get_input_arity())).program
        call_args = [f"x_{pos}" for pos in range(len(args))]
        call_args.insert(index_to_specialize, fixed_value)

        # Three sections: header, embedded function, return clauses.
        # Escaped newlines keep the template text identical to the multi-line layout.
        sections = [
            f"\n        params {program.params - 1}; \n        bounded params 0;\n        ",
            f"\n        f_0 := Func(\n            {program.dsl()}\n        );\n        ",
            f"\n        ReturnExpr f_0({_format_args(call_args)});\n        ReturnPred None;\n        ",
        ]
        template = Z3Template("".join(sections))
        template.set_args(*args)
        return template

    def _transform_examples(
        self, new_concept: Entity, concept: Entity, index_to_specialize: int, value: Any
    ):
        """Transform examples from the base concept to the specialized concept.

        For each tuple-valued example of the base concept:
        1. If the component at ``index_to_specialize`` equals ``value``, the
           example with that component removed becomes an example of the new
           concept (``(True,)`` if nothing is left).
        2. Otherwise, if the base concept is a function and the index is its last
           component (output specialization), the remaining components become a
           nonexample: the function's output for those inputs is not ``value``.
        3. Otherwise the example is skipped.

        For each tuple-valued nonexample of the base concept whose component at
        the index equals ``value``, the remaining components become a nonexample
        of the new concept (``(False,)`` if nothing is left).

        Entries that are not tuples or are too short to contain the index are
        skipped. A ``ValueError`` raised by ``add_example`` / ``add_nonexample``
        is ignored for every transferred entry, so one rejected entry does not
        abort the transfer.

        Args:
            new_concept: The specialized concept that receives the entries.
            concept: The base concept whose examples are transformed.
            index_to_specialize: Position of the fixed component.
            value: The fixed value.
        """
        component_count = len(concept.examples.example_structure.component_types)
        is_output_specialization = (
            index_to_specialize == component_count - 1 and concept.is_function()
        )

        def usable(example_value) -> bool:
            # Guard against non-tuple payloads and tuples too short to index.
            return (
                isinstance(example_value, tuple)
                and len(example_value) > index_to_specialize
            )

        def without_fixed_component(example_value) -> tuple:
            remaining = list(example_value)
            remaining.pop(index_to_specialize)
            return tuple(remaining)

        def add_best_effort(add, payload) -> None:
            # Transfer is best-effort: an entry the new concept rejects is dropped.
            try:
                add(payload)
            except ValueError:
                pass

        # Transform positive examples
        for ex in concept.examples.get_examples():
            if not usable(ex.value):
                continue
            remaining = without_fixed_component(ex.value)
            if ex.value[index_to_specialize] == value:
                add_best_effort(new_concept.add_example, remaining or (True,))
            elif is_output_specialization:
                # The function maps these inputs to something other than the
                # fixed value, so they are a nonexample of the new predicate.
                add_best_effort(new_concept.add_nonexample, remaining)

        # Transform negative examples (if available)
        if hasattr(concept.examples, "get_nonexamples"):
            for ex in concept.examples.get_nonexamples():
                if not usable(ex.value):
                    continue
                if ex.value[index_to_specialize] == value:
                    remaining = without_fixed_component(ex.value)
                    add_best_effort(new_concept.add_nonexample, remaining or (False,))