"""
This module implements the Exists production rule for factoring out arguments via existential quantification.
"""

from typing import List, Tuple, Any, Dict, Type
from frame.productions.base import ProductionRule
import itertools
import random
import time
import logging

from frame.knowledge_base.entities import (
    Expression,
    Exists,
    Var,
    ConceptApplication,
    NatDomain,
    And,
    Entity,
    Concept,
    ExampleType,
    ExampleStructure,
    ConceptType,
    Equals,
)
from frame.knowledge_base.demonstrations import (
    divides,
    is_even,
    addition,
    multiplication,
)
from frame.tools.z3_template import Z3Template, _format_args

# Bounds for the witness search run by the computational implementation of an
# existentially quantified concept.
# TODO(_; 3/25): Extract this to a configuration file or make it configurable per rule
MAX_SEARCH_LIMIT: int = 40  # Upper bound on candidate witness combinations tried per call
MAX_VALUE_PER_VAR: int = 10  # Quantified variables take values in 0..MAX_VALUE_PER_VAR
DEFAULT_INTERNAL_EXISTS_SEARCH_TIMEOUT: float = 0.1  # Seconds allowed for one internal search loop

logger = logging.getLogger(__name__)

class ExistsRule(ProductionRule):
    """
    Production rule that takes a concept and produces a new concept by existentially quantifying over some of its arguments.

    For example:
    - Given divides_and_even(a,b) = divides(b,a) ∧ is_even(b)
    - Produces exists_divides_and_even(a) = ∃b. divides(b,a) ∧ is_even(b)

    The rule:
    1. Takes a concept and indices of arguments to quantify over
    2. Creates new symbolic definition with existential quantifiers
    3. Creates computational implementation that searches for witnesses
    4. Transforms examples based on existence of witnesses

    The produced concept is always a predicate. For a predicate input the quantified
    argument positions are removed; for a function input the unquantified inputs
    followed by all outputs become the arguments of the new predicate.
    """

    def __init__(self, internal_search_timeout: float = DEFAULT_INTERNAL_EXISTS_SEARCH_TIMEOUT):
        """
        Args:
            internal_search_timeout: Time budget in seconds for one witness search run by
                the computational implementation of a concept produced by this rule.
                When it is exceeded the search stops and reports that no witness was found.
        """
        super().__init__(
            name="exists",
            description="Creates a new concept by existentially quantifying over selected arguments",
            type="Concept",
        )
        self.internal_search_timeout = internal_search_timeout
        logger.info(f"ExistsRule initialized with internal_search_timeout: {self.internal_search_timeout}s")

    def determine_verification_capabilities(self, *inputs: Entity) -> Tuple[bool, bool]:
        """
        Existential quantification can reliably verify positive examples (by finding a witness)
        but cannot reliably verify negative examples (would require checking all possible values).

        This capability is further restricted by the input concept's capabilities:
        - If the input concept cannot reliably verify examples, the result also cannot
        - If the input concept cannot reliably verify non-examples, it doesn't affect
          this rule's inherent inability to verify non-examples

        Args:
            *inputs: The input concept(s); only the first one is consulted.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples)
        """
        if not inputs:
            return True, False

        # A found witness proves a positive example only if the input concept's own
        # verification of examples is trustworthy.
        can_add_examples = inputs[0].can_add_examples
        # Failing to find a witness in a bounded search proves nothing, so non-examples
        # are never verifiable here, whatever the input concept can do.
        can_add_nonexamples = False
        return can_add_examples, can_add_nonexamples

    def get_input_types(self) -> List[Tuple[Type, List[ConceptType]]]:
        """Return the expected input types for this rule.

        The exists rule accepts a single concept that must be either a predicate or a function.
        """
        return [(Concept, [ConceptType.PREDICATE, ConceptType.FUNCTION])]

    def get_valid_parameterizations(self, *inputs: Any) -> List[Dict[str, Any]]:
        """Return valid parameterizations for existentially quantifying the given concept.

        Only index sets that ``can_apply`` accepts are produced: any non-empty subset
        of a function's inputs, or any non-empty proper subset of a predicate's
        arguments (at least one predicate argument must stay free). At most 100
        parameterizations are returned, smallest index sets first.

        Args:
            *inputs: A single concept to apply the exists rule to.

        Returns:
            A list of dictionaries, where each dictionary contains a key 'indices_to_quantify'
            with a list of valid indices that can be quantified over.
        """
        if len(inputs) != 1:
            return []

        concept = inputs[0]
        if not isinstance(concept, Concept):
            return []

        concept_type = concept.examples.example_structure.concept_type
        if concept_type not in (ConceptType.PREDICATE, ConceptType.FUNCTION):
            return []

        input_arity = concept.get_input_arity()
        if concept_type == ConceptType.FUNCTION:
            largest_subset = input_arity
        else:
            # Predicates must keep at least one argument unquantified (see can_apply).
            largest_subset = input_arity - 1

        # Limit the number of parameterizations to avoid combinatorial explosion
        max_parameterizations = 100
        valid_parameterizations = []
        for size in range(1, largest_subset + 1):
            for indices in itertools.combinations(range(input_arity), size):
                valid_parameterizations.append({"indices_to_quantify": list(indices)})
                if len(valid_parameterizations) >= max_parameterizations:
                    return valid_parameterizations

        return valid_parameterizations

    def can_apply(
        self, *inputs: Entity, indices_to_quantify: List[int], verbose: bool = True
    ) -> bool:
        """
        Check if this rule can be applied to the inputs.

        Requirements:
        1. Single input concept
        2. For predicates: Input concept must take multiple arguments
        3. For functions: Input concept must take at least one argument
        4. At least one index to quantify; indices must be distinct and lie in
           [0, input arity), so function outputs can never be quantified
        5. At least one argument must remain unquantified for predicates

        Args:
            *inputs: Input entities to check
            indices_to_quantify: List of indices to quantify
            verbose: Whether to print debug information

        Returns:
            True if ``apply`` may be called with the same arguments.
        """
        if verbose:
            print("Checking requirements:")
            print(f"Number of inputs: {len(inputs)}")
            if len(inputs) > 0:
                print(f"First input type: {type(inputs[0])}")
                print(f"First input module: {type(inputs[0]).__module__}")
                print(f"Is instance check: {isinstance(inputs[0], Concept)}")
                print(f"Concept class module: {Concept.__module__}")

        if len(inputs) != 1 or not isinstance(inputs[0], Concept):
            if verbose:
                print("❌ Failed: Must have exactly one input of type Concept")
            return False

        concept = inputs[0]
        if verbose:
            print(f"✓ Input is a concept: {concept.name}")

        try:
            # Get input arity from example structure
            arity = concept.get_input_arity()
            is_function = (
                concept.examples.example_structure.concept_type == ConceptType.FUNCTION
            )

            if is_function:
                if arity < 1:
                    if verbose:
                        print("❌ Failed: Function must take at least one argument")
                    return False
            elif arity <= 1:
                if verbose:
                    print("❌ Failed: Predicate must take multiple arguments")
                return False

            if verbose:
                print(f"✓ Concept takes {arity} arguments")

            if not indices_to_quantify:
                if verbose:
                    print("❌ Failed: Must specify at least one index to quantify")
                return False

            # Negative indices would silently match no argument position in apply().
            if any(i < 0 for i in indices_to_quantify):
                if verbose:
                    print("❌ Failed: Indices to quantify must be non-negative")
                return False

            # Duplicates would make the number of quantified variables disagree with
            # the number of argument positions actually removed.
            if len(set(indices_to_quantify)) != len(indices_to_quantify):
                if verbose:
                    print("❌ Failed: Indices to quantify must be distinct")
                return False

            # Positions >= arity are function outputs (or do not exist for a
            # predicate); neither can be quantified.
            if any(i >= arity for i in indices_to_quantify):
                if verbose:
                    if is_function:
                        print("❌ Failed: Cannot quantify over function outputs")
                    else:
                        print("❌ Failed: Indices to quantify exceed concept arity")
                return False

            if not is_function and len(indices_to_quantify) >= arity:
                if verbose:
                    print(
                        "❌ Failed: Cannot quantify over all arguments for predicates"
                    )
                return False

            if verbose:
                print(f"✓ Valid indices to quantify: {indices_to_quantify}")
            return True

        except AttributeError as e:
            if verbose:
                print(f"❌ Failed: {str(e)}")
            return False

    def apply(self, *inputs: Entity, indices_to_quantify: List[int]) -> Entity:
        """
        Apply the exists rule to create a new concept.

        For predicates P(x₁,...,xₙ), creates a new predicate that existentially quantifies
        over the specified arguments.

        For functions f(x₁,...,xₙ) = (y₁,...,yₘ), creates a predicate P(...,y₁,...,yₘ) that checks
        if there exist values for the quantified arguments such that f(...) = (y₁,...,yₘ).
        The arity depends on how many arguments are quantified:
        - Quantify some inputs: P(...,y₁,...,yₘ) with arity = n - |quantified| + m
        - Quantify all inputs: P(y₁,...,yₘ) with arity = m

        Raises:
            ValueError: If ``can_apply`` rejects the inputs and indices.
        """
        # Diagnostics are printed only when can_apply is called directly with
        # verbose=True; here a rejection surfaces as the ValueError below.
        if not self.can_apply(*inputs, indices_to_quantify=indices_to_quantify, verbose=False):
            raise ValueError("Cannot apply Exists to these inputs")

        concept = inputs[0]
        # Private copy: the closures and metadata below keep referring to it, so a
        # later mutation of the caller's list must not change the new concept.
        indices_to_quantify = list(indices_to_quantify)
        quantified = set(indices_to_quantify)
        is_function = (
            concept.examples.example_structure.concept_type == ConceptType.FUNCTION
        )

        component_types = tuple(concept.examples.example_structure.component_types)
        logger.debug("ExistsRule.apply - Initial component types: %s", component_types)

        if is_function:
            # Keep unquantified inputs followed by all outputs.
            num_inputs = concept.get_input_arity()
            input_types = tuple(
                t
                for i, t in enumerate(component_types[:num_inputs])
                if i not in quantified
            )
            # Slice from num_inputs: `[-num_outputs:]` would return every component
            # when num_outputs == 0.
            output_types = component_types[num_inputs:]
            new_types = input_types + output_types
        else:
            # For predicates, just remove quantified indices
            new_types = tuple(
                t for i, t in enumerate(component_types) if i not in quantified
            )

        # The result is always a predicate, so every remaining component is an input
        # (for functions, the former outputs become inputs too).
        concept_type = ConceptType.PREDICATE
        input_arity = len(new_types)

        logger.debug("Types after removing quantified indices: %s", new_types)
        logger.debug("Final concept type: %s", concept_type)
        logger.debug("Final input arity: %s", input_arity)

        # Get verification capabilities from the rule
        can_add_examples, can_add_nonexamples = self.determine_verification_capabilities(*inputs)

        # A Z3 translation is only possible when the input concept has one.
        z3_translation = None
        if concept.has_z3_translation():
            translate = (
                self._z3_translate_exists_function
                if is_function
                else self._z3_translate_exists_predicate
            )
            z3_translation = lambda *args: translate(concept, indices_to_quantify, *args)  # noqa: E731

        new_concept = Concept(
            name=f"exists_({concept.name}_indices_{indices_to_quantify})",
            description=f"Existential quantification over {concept.name}",
            symbolic_definition=lambda *args: self._build_exists_expr(
                concept, args, indices_to_quantify, is_function
            ),
            computational_implementation=self._build_computational_impl(
                concept, indices_to_quantify, is_function
            ),
            example_structure=ExampleStructure(
                concept_type=concept_type,
                component_types=new_types,
                input_arity=input_arity,
            ),
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
            z3_translation=z3_translation,
        )

        # Store metadata for use in transform_examples
        new_concept._indices_to_quantify = indices_to_quantify
        new_concept._is_function = is_function
        new_concept.map_iterate_depth = concept.map_iterate_depth

        # Transform examples
        self._transform_examples(new_concept, concept)

        return new_concept

    @staticmethod
    def _z3_param_names(count: int, indices: List[int]) -> List[str]:
        """Name ``count`` positional DSL parameters.

        Positions listed in ``indices`` become bound variables ``b_0, b_1, ...``; all
        other positions become free parameters ``x_0, x_1, ...``. Each family is
        numbered in positional order.
        """
        bound = set(indices)
        names = []
        next_free = 0
        next_bound = 0
        for i in range(count):
            if i in bound:
                names.append(f"b_{next_bound}")
                next_bound += 1
            else:
                names.append(f"x_{next_free}")
                next_free += 1
        return names

    def _z3_translate_exists_predicate(self, concept, indices, *args):
        """Build a Z3Template for ∃ b. concept(...) with ``indices`` bound.

        The input predicate's DSL program is wrapped as ``p_0`` and applied to the
        free parameters and bound variables at their original positions.
        """
        template = concept.to_z3(*([None] * concept.get_input_arity()))
        program = template.program

        code = f"""
        params {program.params - len(indices)};
        bounded params {len(indices)};
        p_0 := Pred(
            {program.dsl()}
        );
        """

        existential_quantifier = "[" + ", ".join([f"b_{i}" for i in range(len(indices))]) + "]"
        args_string = _format_args(self._z3_param_names(program.params, indices))
        code += f"""
        ReturnExpr None;
        ReturnPred Exists({existential_quantifier}, p_0({args_string}));
        """

        template = Z3Template(code)
        template.set_args(*args)
        return template

    def _z3_translate_exists_function(self, concept, indices, *args):
        """Build a Z3Template for ∃ b. concept(...) == y with ``indices`` bound.

        The input function's DSL program is wrapped as ``f_0``; the last free
        parameter of the new predicate is the output value being matched.
        """
        # NOTE(review): exactly one extra (output) parameter is declared and compared,
        # so this translation presumably assumes a single-output function — confirm.
        template = concept.to_z3(*([None] * concept.get_input_arity()))
        program = template.program

        code = f"""
        params {program.params - len(indices) + 1};
        bounded params {len(indices)};
        f_0 := Func(
            {program.dsl()}
        );
        """

        existential_quantifier = "[" + ", ".join([f"b_{i}" for i in range(len(indices))]) + "]"
        params = self._z3_param_names(program.params + 1, indices)
        args_string = _format_args(params[:-1])

        code += f"""
        ReturnExpr None;
        ReturnPred Exists({existential_quantifier}, f_0({args_string}) == {params[-1]});
        """

        template = Z3Template(code)
        template.set_args(*args)
        return template

    def _transform_examples(self, new_concept: Entity, concept: Entity):
        """Transform examples from the base concept into examples for the existential concept.

        For functions f(x1,...,xn) = (y1,...,ym):
        - When quantifying some inputs: Examples become (unquantified inputs, y1,...,ym)
        - When quantifying all inputs: Examples become (y1,...,ym)

        For predicates P(x1,...,xn):
        - Examples become tuples of unquantified inputs

        Examples whose value is not a tuple are skipped. Examples that the new concept
        refuses are logged and skipped (best effort).

        Note: We do not transform nonexamples since their meaning under existential
        quantification is not straightforward.
        """
        indices_to_quantify = new_concept._indices_to_quantify
        is_function = new_concept._is_function
        logger.debug("_transform_examples - Indices to quantify: %s", indices_to_quantify)
        logger.debug("Is function: %s", is_function)

        quantified = set(indices_to_quantify)
        num_inputs = concept.get_input_arity() if is_function else None

        for ex in concept.examples.get_examples():
            if not isinstance(ex.value, tuple):
                continue

            if is_function:
                # Outputs (possibly several) follow the inputs and are all kept.
                inputs = ex.value[:num_inputs]
                outputs = ex.value[num_inputs:]
            else:
                inputs = ex.value
                outputs = ()

            new_value = tuple(
                val for i, val in enumerate(inputs) if i not in quantified
            ) + outputs

            try:
                new_concept.add_example(new_value)
            except Exception as e:
                # Best effort: one rejected example must not abort the transformation.
                logger.warning("Failed to add example: %s", e)

    def _build_exists_expr(
        self,
        concept: Concept,
        args: Tuple[Any, ...],
        indices_to_quantify: List[int],
        is_function: bool = False,
    ) -> Expression:
        """Build symbolic expression with existential quantifiers.

        For functions f(x1,...,xn) = (y1,...,ym), we have:
        1. Quantify over some inputs: Creates predicate P(...,y1,...,ym) = ∃ki. f(...,ki,...) = (y1,...,ym)

        For predicates P(x1,...,xn), we have:
        1. Quantify over some inputs: Creates predicate Q(...) = ∃yi. P(...,yi,...)

        ``args`` supplies the unquantified inputs in order, followed (for functions)
        by the output values.
        """
        quantified = set(indices_to_quantify)
        num_inputs = concept.get_input_arity()
        # Bound variables are named k<i> for functions and y<i> for predicates.
        prefix = "k" if is_function else "y"

        # A bound Var goes at each quantified position; the caller's arguments fill
        # the remaining positions in order.
        inner_args = []
        arg_idx = 0
        for i in range(num_inputs):
            if i in quantified:
                inner_args.append(Var(f"{prefix}{i}"))
            else:
                inner_args.append(args[arg_idx])
                arg_idx += 1

        if is_function:
            num_outputs = len(concept.get_component_types()) - num_inputs
            # Explicit start index: `args[-0:]` would select every argument.
            outputs = args[len(args) - num_outputs:]
            inner_expr = Equals(
                ConceptApplication(concept, *inner_args),
                outputs[0] if num_outputs == 1 else outputs,
            )
        else:
            inner_expr = ConceptApplication(concept, *inner_args)

        # Wrap with existential quantifiers, innermost first.
        expr = inner_expr
        for i in reversed(indices_to_quantify):
            expr = Exists(f"{prefix}{i}", NatDomain(), expr)
        return expr

    @staticmethod
    def _candidate_witnesses(num_vars: int):
        """Yield at most MAX_SEARCH_LIMIT distinct candidate witnesses.

        Each witness is a tuple of ``num_vars`` ints in 0..MAX_VALUE_PER_VAR (inclusive).
        If the whole grid of (MAX_VALUE_PER_VAR + 1) ** num_vars assignments fits in
        MAX_SEARCH_LIMIT, it is enumerated exhaustively, so a witness inside the grid
        is never missed. Otherwise distinct grid points are drawn at random, so no
        evaluation is spent on a repeated candidate.
        """
        base = MAX_VALUE_PER_VAR + 1
        if base ** num_vars <= MAX_SEARCH_LIMIT:
            yield from itertools.product(range(base), repeat=num_vars)
            return

        # The grid is larger than the budget, so this loop always terminates.
        seen = set()
        while len(seen) < MAX_SEARCH_LIMIT:
            witness = tuple(random.randint(0, MAX_VALUE_PER_VAR) for _ in range(num_vars))
            if witness not in seen:
                seen.add(witness)
                yield witness

    def _build_computational_impl(
        self,
        concept: Concept,
        indices_to_quantify: List[int],
        is_function: bool = False,
    ) -> callable:
        """
        Build computational implementation for the new concept.

        The implementation searches for a witness. A witness is an assignment of values
        in 0..MAX_VALUE_PER_VAR (inclusive) to the quantified variables that makes the
        input predicate truthy, or that makes the input function return the requested
        outputs. At most MAX_SEARCH_LIMIT candidates are tried (see
        ``_candidate_witnesses``). The search also stops once
        ``self.internal_search_timeout`` seconds have elapsed. A False result therefore
        means "no witness found within the bounded search", not a proof that none
        exists.

        Args:
            concept: The input concept
            indices_to_quantify: Indices of arguments to quantify over
            is_function: Whether the input concept is a function

        Returns:
            A callable taking the new concept's arguments (unquantified inputs, then, for
            functions, the target outputs) and returning a bool. It raises ValueError
            when called with the wrong number of arguments.
        """
        num_vars = len(indices_to_quantify)
        structure = concept.examples.example_structure
        num_inputs = structure.input_arity
        quantified = set(indices_to_quantify)
        # Input positions that are filled, in order, from the caller's arguments.
        free_positions = [i for i in range(num_inputs) if i not in quantified]

        if is_function:
            output_arity = len(structure.component_types) - num_inputs
            search_kind = "function"
        else:
            output_arity = 0
            search_kind = "predicate"
        exists_arity = num_inputs - num_vars + output_arity

        def impl(*args):
            if len(args) != exists_arity:
                raise ValueError(f"Expected {exists_arity} arguments, got {len(args)}")

            full_args = [None] * num_inputs
            for position, value in zip(free_positions, args):
                full_args[position] = value
            # Trailing arguments are the function outputs to match (empty for predicates).
            target = tuple(args[len(free_positions):])

            # Monotonic clock: elapsed time must not jump with wall-clock adjustments.
            start_time = time.monotonic()
            for witness in self._candidate_witnesses(num_vars):
                if time.monotonic() - start_time > self.internal_search_timeout:
                    logger.warning(
                        "ExistsRule computational_impl (%s) internal search timed out after %ss for %s",
                        search_kind,
                        self.internal_search_timeout,
                        concept.name,
                    )
                    break

                for position, value in zip(indices_to_quantify, witness):
                    full_args[position] = value

                try:
                    result = concept.compute(*full_args)
                    if is_function:
                        # Compare as tuples so single- and multi-output functions share a path.
                        if not isinstance(result, tuple):
                            result = (result,)
                        if result == target:
                            return True
                    elif result:
                        return True
                except Exception:
                    # Best-effort search: candidates outside the concept's domain are skipped.
                    continue

            return False

        return impl