"""
Production rule that creates a new concept by imposing a forall quantifier.
"""

import itertools
from typing import Optional, Any, Tuple, List, Dict, Type, Union

import random
import math # Added import
import time # Add time import
import logging # Add logging import
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    Expression,
    Var,
    ConceptApplication,
    Forall,
    Entity,
    Concept,
    Conjecture,
    ExampleType,
    ExampleStructure,
    ConceptType,
    Implies,
)

# Import default domains for each type of concept.
from frame.knowledge_base.entities import (
    NatDomain,
    GroupDomain,
    SetDomain,
    GroupElementDomain,
)

from frame.tools.z3_template import Z3Template, _format_args

# Maximum limit for searching counterexamples in universal quantification.
# NOTE(review): confirm where MAX_SEARCH_LIMIT / MAX_VALUE_PER_VAR are consumed;
# they presumably bound a brute-force counterexample search.
# TODO(future): Extract this to a configuration file or make it configurable per rule
MAX_SEARCH_LIMIT = 100  # Maximum number of combinations to try
MAX_VALUE_PER_VAR = 5  # Maximum value for each variable
# Default for ForallRule(internal_search_timeout=...).
DEFAULT_INTERNAL_FORALL_SEARCH_TIMEOUT = 0.1 # Default timeout for the internal search loop in seconds

# Two-predicate case of ForallRule.get_valid_parameterizations: if the total
# number of candidate parameterizations is at most this, all are enumerated;
# otherwise up to this many are sampled.
NUM_PARAMETERIZATIONS = 50 # Limit for sampled parameterizations

# TODO(_; 4/29): Make sure to add tests for the conjecture production path.
# TODO(_; 4/29): For the double predicate case, perhaps we should have only a single 
# predicate rule, and have a separate implication rule where we can use the match rule.

# Module-level logger (used e.g. by ForallRule.__init__).
logger = logging.getLogger(__name__)

def _factorial(n):
    """Return n! via ``math.factorial`` (raises ValueError for negative n)."""
    result = math.factorial(n)
    return result

def _combinations(n, k):
    """
    Return the binomial coefficient C(n, k) ("n choose k").

    Out-of-range requests (k < 0 or k > n) yield 0 instead of raising, so the
    result can be used directly in counting formulas.
    """
    if 0 <= k <= n:
        return math.comb(n, k)
    return 0

def _validate_mapping(
    indices_to_map: Dict[int, int], 
    primary_concept: Concept, 
    secondary_concept: Concept,
    verbose: bool = False
) -> bool:
    """
    Validate the shared-variable mapping used by the two-predicate forall rule.

    Args:
        indices_to_map: Maps input indices of ``primary_concept`` to input
            indices of ``secondary_concept``; each pair becomes one shared
            variable. ``None`` is treated as an invalid mapping.
        primary_concept: The first concept (P) in the forall rule.
        secondary_concept: The second concept (Q) in the forall rule.
        verbose: If True, print the reason a mapping is rejected.

    Returns:
        bool: True if every key is an index in ``range(arity of P)``, every
        value is an index in ``range(arity of Q)``, and no value repeats (the
        mapping is one-to-one); False otherwise. An empty mapping is valid.
    """
    if indices_to_map is None:
        if verbose:
            print("❌ Failed: No variable map provided")
        return False

    primary_arity = primary_concept.get_input_arity()
    secondary_arity = secondary_concept.get_input_arity()

    # Both ends of every pair must be in range. Negative indices are rejected
    # explicitly because Python would otherwise silently index from the end.
    keys_in_range = all(0 <= i < primary_arity for i in indices_to_map.keys())
    values_in_range = all(0 <= j < secondary_arity for j in indices_to_map.values())
    if not (keys_in_range and values_in_range):
        if verbose:
            print("❌ Failed: Invalid variable map indices")
        return False

    # Dict keys are unique already, so only the values can break injectivity.
    if len(set(indices_to_map.values())) != len(indices_to_map):
        if verbose:
            print("❌ Failed: Invalid variable map indices")
        return False

    # TODO(_; 4/29): Check type compatibility of mapped indices when more
    # types are added.
    return True

class ForallRule(ProductionRule):
    """
    Production rule that creates a new concept by imposing a forall quantifier over arguments of one or two predicates.

    Case 1 (two predicates): Given P(x₁,...,xₙ), Q(y₁,...,yₘ), and indices i₁,...,ik:
      R(unquantified args) := ∀ xᵢ₁ ∈ D₁, ..., xᵢk ∈ Dk, P(x₁,...,xₙ) ⇒ Q(xᵢ₁,...,xᵢk) (k <= m)

    Case 2 (single predicate): Given P(x₁,...,xₙ) and indices i₁,...,iₘ:
      R(unquantified args) := ∀ xᵢ₁ ∈ D₁, ..., xᵢₘ ∈ Dₘ, P(x₁,...,xₙ) (m <= n)

    Domains D₁, ..., Dₘ are inferred from the types. Note(_; 4/25: Right now, only conjecture support)
    """

    def __init__(self, verbose=False, internal_search_timeout: float = DEFAULT_INTERNAL_FORALL_SEARCH_TIMEOUT):
        """
        Create the forall production rule.

        Args:
            verbose: Stored as ``self.verbose``.
            internal_search_timeout: Timeout in seconds for the internal search
                loop; stored as ``self.internal_search_timeout`` and logged at
                INFO level.
        """
        description_lines = (
            "Creates a new predicate with universal quantification in two modes:",
            "1. Two predicates: Given predicates P and Q, creates R where",
            "   R(unquantified args) := ∀ quantified args. P(...) ⇒ Q(...)",
            "2. Single predicate: Given predicate P, creates R where",
            "   R(unquantified args) := ∀ quantified args. P(...)",
            "Domains D₁, ..., Dₘ are inferred from example structure.",
        )
        super().__init__(
            name="forall",
            description="\n".join(description_lines),
            type="Concept",
        )
        self.verbose = verbose
        self.internal_search_timeout = internal_search_timeout
        logger.info(f"ForallRule initialized with internal_search_timeout: {self.internal_search_timeout}s")

    def determine_verification_capabilities(self, *inputs: Entity) -> Tuple[bool, bool]:
        """
        Report whether the new entity can reliably verify examples / non-examples.

        Universal quantification can never confirm a positive example (that
        would mean checking every value), but a single counterexample is
        enough to confirm a non-example. Non-example verification is further
        limited by what the inputs can verify:

        - Single predicate (∀x. P(x)): a counterexample is an x with ¬P(x), so
          P must reliably verify non-examples.
        - Two predicates (∀x. P(x) ⇒ Q(x)): a counterexample is an x with P(x)
          and ¬Q(x), so P must reliably verify examples and Q must reliably
          verify non-examples.
        - No inputs, or more than two: the inherent capabilities are returned.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples)
        """
        # Inherent capabilities of universal quantification.
        inherent_can_add_examples = False
        inherent_can_add_nonexamples = True

        if len(inputs) == 1:
            (concept,) = inputs
            can_add_nonexamples = (
                inherent_can_add_nonexamples and concept.can_add_nonexamples
            )
            return inherent_can_add_examples, can_add_nonexamples

        if len(inputs) == 2:
            primary_concept, secondary_concept = inputs
            can_add_nonexamples = (
                inherent_can_add_nonexamples
                and primary_concept.can_add_examples
                and secondary_concept.can_add_nonexamples
            )
            return inherent_can_add_examples, can_add_nonexamples

        return inherent_can_add_examples, inherent_can_add_nonexamples

    def get_input_types(
        self,
    ) -> List[
        List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]
    ]:
        """
        Return the accepted input signatures for this rule.

        Two signatures are accepted: a single predicate concept, or two
        predicate concepts (premise and conclusion).
        """
        predicate_slot = (Concept, ConceptType.PREDICATE)
        single_predicate_signature = [predicate_slot]
        two_predicate_signature = [predicate_slot, predicate_slot]
        return [single_predicate_signature, two_predicate_signature]

    def get_valid_parameterizations(self, *inputs: Entity) -> List[Dict[str, Any]]:
        """
        Return parameter dicts for ``apply``, sampling when there are too many.

        Single predicate of arity n: every non-empty subset of ``range(n)`` as
        ``{"indices_to_quantify": [...]}``. This case is never sampled.

        Two predicates of arities n and m: dicts with
        ``"indices_to_map"`` (a one-to-one mapping of k >= 1 primary indices
        onto secondary indices) and ``"indices_to_quantify"`` (a non-empty
        subset of the n + m - k distinct variable indices). There are
            sum over k of C(n, k) * C(m, k) * k! * (2^(n+m-k) - 1)
        such parameterizations. If that total is at most NUM_PARAMETERIZATIONS
        all are enumerated; otherwise up to NUM_PARAMETERIZATIONS are sampled.
        In both paths, parameterizations rejected by ``can_apply`` are
        reported and dropped.

        Any other number of inputs yields an empty list.
        """
        if len(inputs) == 1:
            arity = inputs[0].get_input_arity()
            return [
                {"indices_to_quantify": list(indices)}
                for size in range(1, arity + 1)
                for indices in itertools.combinations(range(arity), size)
            ]

        if len(inputs) != 2:
            return []

        primary, secondary = inputs
        n = primary.get_input_arity()
        m = secondary.get_input_arity()

        # Count parameterizations per mapping size k = 1..min(n, m). Mapping
        # size 0 (no shared variable) is deliberately excluded.
        counts_per_k = []
        for k in range(1, min(n, m) + 1):
            # Injective maps of size k: choose k primary slots, k secondary
            # slots, and a bijection between them.
            num_mappings_k = _combinations(n, k) * _combinations(m, k) * _factorial(k)
            # Only NON-EMPTY quantification subsets are generated, hence the -1.
            num_quantifications_k = (1 << (n + m - k)) - 1
            counts_per_k.append(num_mappings_k * num_quantifications_k)
        total_parameterizations = sum(counts_per_k)

        if total_parameterizations == 0:
            return []

        if total_parameterizations > NUM_PARAMETERIZATIONS:
            return self._sample_two_predicate_params(
                primary, secondary, counts_per_k, total_parameterizations
            )

        # Small enough to enumerate exhaustively. Drop anything can_apply
        # rejects, matching the sampling path (which skips such samples).
        valid_parameterizations = []
        for param in self._generate_all_two_predicate_params(primary, secondary):
            if not self.can_apply(primary, secondary, **param):
                # This should not happen if the generation logic is correct.
                print(f"Warning: Generated parameterization failed can_apply check: {param}")
                continue
            valid_parameterizations.append(param)
        return valid_parameterizations

    def _generate_all_two_predicate_params(self, primary: Concept, secondary: Concept) -> List[Dict[str, Any]]:
        """
        Enumerate every parameterization for the two-predicate case.

        A parameterization pairs a non-empty one-to-one mapping from primary
        input indices to secondary input indices (``"indices_to_map"``) with
        a non-empty subset of the ``n + m - k`` resulting distinct variable
        indices (``"indices_to_quantify"``), where k is the mapping size.

        Each returned dict owns its own copy of the mapping, so mutating one
        parameterization never affects another.
        """
        primary_arity = primary.get_input_arity()
        secondary_arity = secondary.get_input_arity()

        def iter_mappings(primary_idx, current_map, used_secondary):
            """Yield every one-to-one partial mapping extending current_map, depth-first."""
            if primary_idx == primary_arity:
                yield dict(current_map)
                return
            # Leave primary_idx unmapped first, then try each free secondary index.
            yield from iter_mappings(primary_idx + 1, current_map, used_secondary)
            for secondary_idx in range(secondary_arity):
                if secondary_idx not in used_secondary:
                    yield from iter_mappings(
                        primary_idx + 1,
                        {**current_map, primary_idx: secondary_idx},
                        used_secondary | {secondary_idx},
                    )

        valid_parameterizations = []
        for mapping in iter_mappings(0, {}, frozenset()):
            if not mapping:
                continue  # at least one shared variable (k >= 1) is required
            distinct_vars_count = primary_arity + secondary_arity - len(mapping)
            # All non-empty subsets of the distinct variables, smallest first.
            for size in range(1, distinct_vars_count + 1):
                for indices in itertools.combinations(range(distinct_vars_count), size):
                    valid_parameterizations.append({
                        "indices_to_map": dict(mapping),
                        "indices_to_quantify": list(indices),
                    })
        return valid_parameterizations

    def _sample_two_predicate_params(self, primary: Concept, secondary: Concept, counts_per_k: List[int], total_parameterizations: int) -> List[Dict[str, Any]]:
        """
        Sample up to NUM_PARAMETERIZATIONS distinct two-predicate parameterizations
        without enumerating them all.

        Each draw:
          1. picks a mapping size k with weight ``counts_per_k[k - 1]``;
          2. picks a uniformly random one-to-one mapping of k primary input
             indices onto k secondary input indices;
          3. picks a uniformly random non-empty subset of the ``n + m - k``
             distinct variable indices to quantify (returned sorted).
        When ``counts_per_k`` holds the exact number of parameterizations per
        k, every parameterization is equally likely. Duplicate draws and draws
        rejected by ``can_apply`` are skipped.

        Args:
            primary: The premise predicate P (arity n).
            secondary: The conclusion predicate Q (arity m).
            counts_per_k: Weight for each mapping size k = 1..min(n, m).
            total_parameterizations: Total candidate count; caps the number
                of draws at twice this value.

        Returns:
            List of ``{"indices_to_map": ..., "indices_to_quantify": ...}`` dicts.
        """
        n = primary.get_input_arity()
        m = secondary.get_input_arity()
        # counts_per_k[j] is the weight of mapping size k = j + 1.
        mapping_sizes = list(range(1, min(n, m) + 1))

        if not counts_per_k or sum(counts_per_k) == 0:
            print("Warning: No parameterizations to sample from based on counts.")
            return []

        sampled_list = []
        seen_keys = set()
        # Heuristic cap so repeated duplicate/invalid draws cannot loop forever.
        max_sampling_attempts = total_parameterizations * 2
        attempts = 0
        while len(sampled_list) < NUM_PARAMETERIZATIONS and attempts < max_sampling_attempts:
            attempts += 1

            # 1. Mapping size, weighted by how many parameterizations use it.
            chosen_k = random.choices(mapping_sizes, weights=counts_per_k, k=1)[0]

            # 2. random.sample returns k distinct indices in random order, so
            #    zipping two draws gives a uniformly random one-to-one mapping.
            #    Sorting only canonicalizes the key order.
            primary_indices = random.sample(range(n), k=chosen_k)
            secondary_indices = random.sample(range(m), k=chosen_k)
            indices_to_map = dict(sorted(zip(primary_indices, secondary_indices)))

            # 3. A random bitmask redrawn while zero is uniform over non-empty
            #    subsets. n + m - k >= 1 since 1 <= k <= min(n, m), so this
            #    takes on average at most 2 draws.
            distinct_vars_count = n + m - chosen_k
            mask = 0
            while mask == 0:
                mask = random.getrandbits(distinct_vars_count)
            indices_to_quantify = [i for i in range(distinct_vars_count) if (mask >> i) & 1]

            # 4. Keep each parameterization at most once; invalid draws are
            #    remembered too so they are not re-checked and re-reported.
            param_key = (tuple(indices_to_map.items()), tuple(indices_to_quantify))
            if param_key in seen_keys:
                continue
            seen_keys.add(param_key)

            param = {"indices_to_map": indices_to_map, "indices_to_quantify": indices_to_quantify}
            if not self.can_apply(primary, secondary, **param):
                print(f"Warning: Sampled an invalid parameterization according to can_apply: {param}. Skipping.")
                continue
            sampled_list.append(param)

        return sampled_list

    def can_apply(
        self,
        *inputs: Entity,
        indices_to_quantify: Optional[List[int]] = None,
        indices_to_map: Optional[Dict[int, int]] = None,
        verbose: bool = False,
    ) -> bool:
        """
        Check if this rule can be applied to the inputs with these parameters.

        Requirements for both cases:
        1. Inputs are one or two predicate concepts.
        2. ``indices_to_quantify`` is non-empty and has no duplicates.

        Case 1 (two predicates P, Q with arities n, m):
        1. n >= 1 and m >= 1.
        2. ``indices_to_map`` is given and is a valid one-to-one mapping from
           P's input indices to Q's input indices (see ``_validate_mapping``).
        3. Every index to quantify lies in ``range(n + m - len(indices_to_map))``,
           i.e. indexes one of the distinct variables left after mapped
           arguments are shared.

        Case 2 (single predicate P with arity n):
        1. n >= 1.
        2. Every index to quantify lies in ``range(n)``. Quantifying every
           argument is allowed (the rule then produces a conjecture).

        Args:
            *inputs: One or two predicate concepts.
            indices_to_quantify: Indices of the variables to quantify over.
            indices_to_map: Two-predicate case only: P index -> Q index.
            verbose: If True, print the reason for rejection.

        Returns:
            bool: True if ``apply`` may be called with these arguments.
        """
        if verbose:
            print("\nChecking if ForallRule can be applied:")
            print(f"Number of inputs: {len(inputs)}")
            if indices_to_map:
                print(f"Indices to map: {indices_to_map}")
                print(f"Indices_to_quantify: {indices_to_quantify}")
            else:
                print(f"Indices to quantify: {indices_to_quantify}")

        # Validate input types
        if len(inputs) == 1:
            if (
                not isinstance(inputs[0], Concept)
                or inputs[0].examples.example_structure.concept_type
                != ConceptType.PREDICATE
            ):
                if verbose:
                    print("❌ Single input must be a predicate concept")
                return False
        elif len(inputs) == 2:
            if not all(
                isinstance(inp, Concept)
                and inp.examples.example_structure.concept_type == ConceptType.PREDICATE
                for inp in inputs
            ):
                if verbose:
                    print("❌ Both inputs must be predicate concepts")
                return False
        else:
            if verbose:
                print("❌ ForallRule requires either one or two input concepts")
            return False

        if not indices_to_quantify:
            if verbose:
                print("❌ Must specify at least one index to quantify.")
            return False

        # A repeated index would make len(indices_to_quantify) overcount the
        # quantified variables (e.g. [0, 0] on arity 2 looks like "all").
        if len(set(indices_to_quantify)) != len(indices_to_quantify):
            if verbose:
                print("❌ Indices to quantify must not contain duplicates.")
            return False

        if len(inputs) == 1:
            # Case 2: Single predicate
            concept = inputs[0]

            n = concept.get_input_arity()
            if verbose:
                print(f"Input arity: {n}")
            if n < 1:
                if verbose:
                    print("❌ Concept must have arity at least 1.")
                return False

            for idx in indices_to_quantify:
                if idx not in range(n):
                    if verbose:
                        print(
                            f"❌ Index {idx} is out of range for concept with arity {n}."
                        )
                    return False
            return True

        # Case 1: Two predicates
        primary, secondary = inputs

        n = primary.get_input_arity()
        if n < 1:
            if verbose:
                print("❌ Primary concept must have arity at least 1.")
            return False

        m = secondary.get_input_arity()
        if m < 1:
            if verbose:
                print("❌ Secondary concept must have arity at least 1.")
            return False

        # The two-predicate case needs a mapping; without one the check below
        # would crash instead of rejecting.
        if indices_to_map is None:
            if verbose:
                print("❌ indices_to_map is required when two predicates are given.")
            return False

        if not _validate_mapping(indices_to_map, primary, secondary, verbose):
            return False

        # Each mapped pair collapses two arguments into one shared variable.
        num_distinct_vars = n + m - len(indices_to_map)
        if not all(i in range(num_distinct_vars) for i in indices_to_quantify):
            if verbose:
                print("❌ Indices to quantify must be a subset of the remaining indices.")
            return False

        return True

    def apply(
        self,
        *inputs: Entity,
        indices_to_quantify: List[int] = None,
        indices_to_map: Dict[int, int] = None,
    ) -> Entity:
        """
        Create the universally quantified entity described by the parameters.

        Two predicates P and Q: ``indices_to_map`` pairs arguments of P with
        arguments of Q (shared variables), and ``indices_to_quantify`` selects
        which of the resulting distinct variables are universally quantified:
          R(unquantified args) := ∀ quantified args. P(...) ⇒ Q(...)

        Single predicate P: ``indices_to_quantify`` selects argument positions of P:
          R(unquantified args) := ∀ quantified args. P(...)

        Raises:
            ValueError: If ``can_apply`` rejects the inputs and parameters.
        """
        is_applicable = self.can_apply(
            *inputs,
            indices_to_quantify=indices_to_quantify,
            indices_to_map=indices_to_map,
        )
        if not is_applicable:
            raise ValueError("Cannot apply ForallRule to these inputs")

        if len(inputs) != 1:
            return self._apply_two_predicates(
                inputs[0], inputs[1], indices_to_quantify, indices_to_map
            )
        return self._apply_single_predicate(inputs[0], indices_to_quantify)

    def _apply_single_predicate(
        self,
        concept: Entity,
        indices_to_quantify: List[int],
    ) -> Entity:
        """
        Apply the forall rule to a single predicate P of arity n.

        If at least one argument stays unquantified, returns a new predicate
        Concept over the kept arguments (in their original order):
            R(kept args) := ∀ quantified args. P(...)
        If every argument is quantified, returns a Conjecture
            ∀x0 ... ∀x{n-1}. P(x0, ..., x{n-1})
        where each variable currently ranges over NatDomain().

        Args:
            concept: The predicate P.
            indices_to_quantify: Argument positions of P to quantify over.

        Returns:
            The new Concept, or a Conjecture when all arguments are quantified.
        """
        concept_arity = concept.get_input_arity()
        # Built from the caller's value so generated names are unchanged.
        name = f"forall_({concept.name}_indices_to_quantify_{indices_to_quantify})"
        # Private copy: the closures and metadata below must not observe later
        # mutations of the caller's list.
        indices_to_quantify = list(indices_to_quantify)
        quantified = set(indices_to_quantify)
        # Decide on the SET of quantified positions rather than the list length,
        # so a repeated index cannot turn a partial quantification into a conjecture.
        kept_indices = [i for i in range(concept_arity) if i not in quantified]

        can_add_examples, can_add_nonexamples = (
            self.determine_verification_capabilities(concept)
        )
        z3_translation = (
            (lambda *args: self._z3_translate_forall_single_predicate(
                concept, indices_to_quantify, *args
            ))
            if concept.has_z3_translation()
            else None
        )

        if kept_indices:
            # TODO: infer per-variable domains instead of leaving them implicit.
            concept_types = concept.get_component_types()
            new_types = tuple(concept_types[i] for i in kept_indices)

            new_concept = Concept(
                name=name,
                description=f"Universal quantification over {concept.name}",
                symbolic_definition=lambda *args: self._build_single_forall_expr(
                    concept, args, indices_to_quantify, kept_indices
                ),
                computational_implementation=self._build_computational_impl(
                    concept,
                    indices=indices_to_quantify,
                    kept_indices=kept_indices,
                ),
                example_structure=ExampleStructure(
                    concept_type=ConceptType.PREDICATE,
                    component_types=new_types,
                    input_arity=len(kept_indices),
                ),
                can_add_examples=can_add_examples,
                can_add_nonexamples=can_add_nonexamples,
                z3_translation=z3_translation,
            )

            # Metadata read by _transform_examples and other consumers.
            new_concept._indices = indices_to_quantify
            new_concept._kept_indices = kept_indices
            new_concept._is_implication = False
            new_concept.map_iterate_depth = concept.map_iterate_depth

            self._transform_examples(
                new_concept=new_concept,
                primary=concept,
                secondary=None,
                indices_to_quantify=indices_to_quantify,
                kept_indices=kept_indices,
                var_info=None,  # Not needed for single predicate case
                parameter_var_ids=None,  # Use kept_indices instead
                quantified_var_ids=None,  # Use indices_to_quantify instead
            )
            return new_concept

        # Every argument is quantified: the result is a closed conjecture.
        # TODO: every variable ranges over NatDomain() regardless of the
        # concept's component types; infer domains from the types instead.
        expr = ConceptApplication(concept, [Var(f"x{i}") for i in range(concept_arity)])
        for i in reversed(range(concept_arity)):
            expr = Forall(f"x{i}", NatDomain(), expr)

        new_conjecture = Conjecture(
            name=name,
            description=f"Universal quantification conjecture over {concept.name}",
            symbolic_definition=lambda: expr,
            example_structure=ExampleStructure(
                concept_type=ConceptType.PREDICATE,
                component_types=(),
                input_arity=0
            ),
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
            z3_translation=z3_translation,
        )
        new_conjecture.map_iterate_depth = concept.map_iterate_depth
        return new_conjecture

    def _z3_translate_forall_single_predicate(self, concept, indices_to_quantify, *args):
        """
        Build the Z3 template for ∀ (quantified args). P(...) over a single predicate P.

        Each argument position of P becomes either a free parameter ``x_j``
        (unquantified) or a bounded parameter ``b_j`` (quantified). Both are
        numbered left to right by position. The template returns
        ``ForAll([b_0, ...], p_0(...))``, where ``p_0`` wraps P's own program.

        Args:
            concept: The predicate P; its ``to_z3`` template supplies the program.
            indices_to_quantify: Argument positions of P to quantify over.
                A repeated position is counted once.
            *args: Values for the free parameters, passed to ``Z3Template.set_args``.

        Returns:
            Z3Template: The configured template.
        """
        template = concept.to_z3(*([None] * concept.get_input_arity()))
        program = template.program

        # Classify every argument position in one pass, and derive the header
        # counts from it so they always agree with the names actually used.
        quantified = set(indices_to_quantify)
        params = []
        num_free = 0
        num_bounded = 0
        for position in range(program.params):
            if position in quantified:
                params.append(f"b_{num_bounded}")
                num_bounded += 1
            else:
                params.append(f"x_{num_free}")
                num_free += 1

        code = f"""
        params {num_free};
        bounded params {num_bounded};
        """

        code += f"""
        p_0 := Pred(
            {program.dsl()}
        );
        """

        bound_vars = "[" + ", ".join(f"b_{i}" for i in range(num_bounded)) + "]"
        args_string = _format_args(params)
        code += f"""
        ReturnExpr None;
        ReturnPred ForAll({bound_vars}, p_0({args_string}));
        """

        template = Z3Template(code)
        template.set_args(*args)
        return template



    def _apply_two_predicates(
        self,
        primary: Entity,
        secondary: Entity,
        indices_to_quantify: List[int],
        indices_to_map: Dict[int, int],
    ) -> Entity:
        """Apply the forall rule to two predicates (P => Q)."""
        assert indices_to_map is not None # indices_to_quantify can be empty

        primary_arity = primary.get_input_arity()
        primary_types = primary.get_component_types()
        secondary_arity = secondary.get_input_arity()
        secondary_types = secondary.get_component_types()

        # 1. Identify and categorize all variables
        var_info = {}
        distinct_logical_vars_ordered_ids = [] # Maintain order [primary unique, secondary unique, shared]
        var_counter = 0

        # Store primary unique vars
        primary_unique_indices = set(range(primary_arity)) - set(indices_to_map.keys())
        for i in sorted(list(primary_unique_indices)):
            var_id = f"v{var_counter}"
            var_info[var_id] = {
                "type": primary_types[i],
                "origin": "primary_unique",
                "primary_idx": i,
                "role": None, # Determined later
                "var_obj": Var(var_id)
            }
            distinct_logical_vars_ordered_ids.append(var_id)
            var_counter += 1

        # Store secondary unique vars
        mapped_secondary_indices = set(indices_to_map.values())
        secondary_unique_indices = set(range(secondary_arity)) - mapped_secondary_indices
        for i in sorted(list(secondary_unique_indices)):
            var_id = f"v{var_counter}"
            var_info[var_id] = {
                "type": secondary_types[i],
                "origin": "secondary_unique",
                "secondary_idx": i,
                "role": None, # Determined later
                "var_obj": Var(var_id)
            }
            distinct_logical_vars_ordered_ids.append(var_id)
            var_counter += 1

        # Store shared vars (use primary index as reference)
        for i in sorted(indices_to_map.keys()):
            var_id = f"v{var_counter}"
            secondary_idx = indices_to_map[i]
            # Type check (basic, assumes types match if mapped)
            if primary_types[i] != secondary_types[secondary_idx]:
                raise TypeError(f"Type mismatch for shared variable: P[{i}] ({primary_types[i]}) != Q[{secondary_idx}] ({secondary_types[secondary_idx]})")
            var_info[var_id] = {
                "type": primary_types[i], # Type from primary
                "origin": "shared",
                "primary_idx": i,
                "secondary_idx": secondary_idx,
                "role": None, # Determined later
                "var_obj": Var(var_id)
            }
            distinct_logical_vars_ordered_ids.append(var_id)
            var_counter += 1

        # Total number of distinct logical variables
        num_distinct_vars = len(distinct_logical_vars_ordered_ids)

        # 2. Assign roles (quantified or parameter) based on indices_to_quantify
        # indices_to_quantify now refers to indices in distinct_logical_vars_ordered_ids
        quantified_var_ids = set()
        parameter_var_ids = [] # Keep order for concept args

        if indices_to_quantify is None:
             indices_to_quantify = [] # Treat None as empty list

        # --- Role Assignment V3: Based on Iterating indices_to_quantify --- 
        target_quantified_origins = {} # Store {origin_key: True} for vars to be quantified
                                     # origin_key format: ('P', index) or ('Q', index)
        
        # Determine which unique secondary indices correspond to higher quantify indices
        mapped_secondary_indices = set(indices_to_map.values())
        secondary_unique_indices = sorted(list(set(range(secondary_arity)) - mapped_secondary_indices))

        for q_idx in indices_to_quantify:
            if q_idx < primary_arity:
                # Corresponds to P[q_idx]
                target_quantified_origins[('P', q_idx)] = True
            else:
                # Corresponds to a unique secondary variable
                secondary_q_map_idx = q_idx - primary_arity
                if secondary_q_map_idx < len(secondary_unique_indices):
                    actual_secondary_idx = secondary_unique_indices[secondary_q_map_idx]
                    target_quantified_origins[('Q', actual_secondary_idx)] = True
                else:
                    # This index is out of bounds for unique secondary vars
                    print(f"Warning: index_to_quantify {q_idx} is out of range for unique secondary variables.")

        # Now assign roles based on the target_quantified_origins map
        quantified_var_ids = set()
        parameter_var_ids = []
        for var_id, info in var_info.items():
            is_quantified = False
            if info['origin'] == 'primary_unique' or info['origin'] == 'shared':
                if ('P', info['primary_idx']) in target_quantified_origins:
                    is_quantified = True
            elif info['origin'] == 'secondary_unique':
                 if ('Q', info['secondary_idx']) in target_quantified_origins:
                      is_quantified = True
                      
            if is_quantified:
                 info["role"] = "quantified"
                 quantified_var_ids.add(var_id)
            else:
                 info["role"] = "parameter"
                 parameter_var_ids.append(var_id)
        # --- End Role Assignment V3 ---

        # 3. Determine if it's a Concept or Conjecture
        is_conjecture = (len(parameter_var_ids) == 0)
        new_arity = len(parameter_var_ids)
        new_types = tuple(var_info[var_id]["type"] for var_id in parameter_var_ids)

        # 4. Get verification capabilities
        can_add_examples, can_add_nonexamples = (
            self.determine_verification_capabilities(primary, secondary)
        )

        # 5. Build Symbolic Definition
        symbolic_definition_lambda = self._build_symbolic_definition_two_pred(
            primary, secondary, var_info, quantified_var_ids, parameter_var_ids, is_conjecture
        )

        # 6. Build Computational Implementation
        computational_impl_lambda = self._build_computational_impl_two_pred(
             primary, secondary, var_info, quantified_var_ids, parameter_var_ids
        )

        # 7. Create Concept or Conjecture
        name = f"forall_({primary.name}_with_{secondary.name}_indices_to_map_{indices_to_map}_indices_to_quantify_{indices_to_quantify})"
        description = f"Universal quantification over {primary.name} implying {secondary.name}"

        if is_conjecture:
            new_entity = Conjecture(
                name=name,
                description=description,
                symbolic_definition=symbolic_definition_lambda,
                example_structure=ExampleStructure(
                    concept_type=ConceptType.PREDICATE,
                    component_types=(),
                    input_arity=0
                ),
                can_add_examples=can_add_examples,
                can_add_nonexamples=can_add_nonexamples,
                z3_translation = (lambda *args: (
                    self._z3_translate_forall_two_predicates(
                        primary, secondary, indices_to_quantify, indices_to_map, *args
                    )
                )) if primary.has_z3_translation() and secondary.has_z3_translation()
                else None,
            )
        else:
            new_entity = Concept(
                name=name,
                description=description,
                symbolic_definition=symbolic_definition_lambda,
                computational_implementation=computational_impl_lambda,
                example_structure=ExampleStructure(
                    concept_type=ConceptType.PREDICATE,
                    component_types=new_types,
                    input_arity=new_arity,
                ),
                can_add_examples=can_add_examples,
                can_add_nonexamples=can_add_nonexamples,
                z3_translation = (lambda *args: (
                    self._z3_translate_forall_two_predicates(
                        primary, secondary, indices_to_quantify, indices_to_map, *args
                    )
                )) if primary.has_z3_translation() and secondary.has_z3_translation()
                else None,
            )
            
            # Note(_; 4/30): In theory we could use this to invalidate conjectures.
            # Transform examples using available data
            self._transform_examples(
                new_concept=new_entity,
                primary=primary,
                secondary=secondary,
                indices_to_quantify=None, # Use quantified_var_ids instead
                kept_indices=None,      # Use parameter_var_ids instead
                var_info=var_info,
                parameter_var_ids=parameter_var_ids,
                quantified_var_ids=quantified_var_ids
            )

        return new_entity

    def _infer_domains(self, var_types: List[ExampleType]) -> List[Entity]:
        """Infer a default domain for each variable type.

        Args:
            var_types: Example types of the variables to be quantified.

        Returns:
            One domain per entry of ``var_types``, in the same order. Within a
            single call, repeated types share the same domain instance.

        Raises:
            NotImplementedError: If a type has no default domain, including
                ``ExampleType.FUNCTION``, which is not supported yet.
        """
        # Map each supported type to a domain *factory*. Domains are built only
        # for the types actually requested. Unsupported types raise, rather than
        # leaking an exception object into the returned list.
        domain_factories = {
            ExampleType.NUMERIC: NatDomain,
            ExampleType.GROUPELEMENT: GroupElementDomain,
            ExampleType.GROUP: GroupDomain,
            ExampleType.SET: SetDomain,
        }
        built: Dict[ExampleType, Entity] = {}
        domains = []
        for var_type in var_types:
            if var_type == ExampleType.FUNCTION:
                raise NotImplementedError("Function domains not handled yet")
            if var_type not in built:
                factory = domain_factories.get(var_type)
                if factory is None:
                    raise NotImplementedError(
                        f"Domain inference not implemented for type {var_type}"
                    )
                built[var_type] = factory()
            domains.append(built[var_type])
        return domains
    
    def _z3_translate_forall_two_predicates(
            self,
            primary,
            secondary,
            indices_to_quantify,
            indices_to_map,
            *args
        ):
        """Translate ``ForAll(b..., P(...) => Q(...))`` into a Z3 template.

        Logical variables are laid out in "merged" order: every primary
        argument position, followed by the secondary positions that are not
        mapped onto a primary position. ``indices_to_quantify`` indexes into
        that merged order, the same convention the rule uses when assigning
        roles. The selected variables become the bound names ``b_0, b_1, ...``.

        The remaining free variables become the template parameters
        ``x_0, x_1, ...``. They are numbered in the order the rule uses for the
        new concept's arguments, which is also the order its computational
        implementation, symbolic definition and derived examples use:

        1. unique primary positions;
        2. unique secondary positions;
        3. shared positions.

        Each group is in ascending order.

        Args:
            primary: Antecedent concept P (must support ``to_z3``).
            secondary: Consequent concept Q (must support ``to_z3``).
            indices_to_quantify: Merged-order indices of the variables to bind.
                ``None`` is treated as empty; out-of-range indices select nothing.
            indices_to_map: Maps a primary argument index to the secondary
                argument index it shares a variable with.
            *args: Concrete arguments of the new concept, forwarded to
                ``Z3Template.set_args``.

        Returns:
            The populated ``Z3Template``.
        """
        primary_program = primary.to_z3(*([None] * primary.get_input_arity())).program
        secondary_program = secondary.to_z3(*([None] * secondary.get_input_arity())).program
        primary_arity = primary_program.params
        secondary_arity = secondary_program.params

        # Q index -> P index for shared arguments. Inverted once, not per lookup.
        q_to_p = {q_idx: p_idx for p_idx, q_idx in indices_to_map.items()}
        unmatched_q = [j for j in range(secondary_arity) if j not in q_to_p]

        # Merged order: every P position, then the Q positions not shared with P.
        merged = [("primary", i) for i in range(primary_arity)]
        merged += [("secondary", j) for j in unmatched_q]
        quantify_set = set(indices_to_quantify or ())
        quantified = {var for pos, var in enumerate(merged) if pos in quantify_set}

        name_table = {}

        # Bound variables b_0, b_1, ... follow merged order. Counting the names
        # actually issued keeps the declared count consistent with the names
        # used, even if an index is out of range or repeated.
        bound_names = []
        for var in merged:
            if var in quantified:
                name_table[var] = f"b_{len(bound_names)}"
                bound_names.append(name_table[var])

        # Free variables x_0, x_1, ... must line up with the new concept's
        # arguments: unique P positions, unique Q positions, then shared positions.
        # NOTE(review): assumes Z3Template.set_args binds args positionally to
        # x_0, x_1, ... -- confirm against frame.tools.z3_template.
        free_order = [("primary", i) for i in range(primary_arity) if i not in indices_to_map]
        free_order += [("secondary", j) for j in unmatched_q]
        free_order += [("primary", i) for i in sorted(indices_to_map)]
        num_free = 0
        for var in free_order:
            if var not in quantified:
                name_table[var] = f"x_{num_free}"
                num_free += 1

        primary_args = [name_table[("primary", i)] for i in range(primary_arity)]
        # Shared Q positions reuse the name of their P counterpart.
        secondary_args = [
            name_table[("primary", q_to_p[j])] if j in q_to_p else name_table[("secondary", j)]
            for j in range(secondary_arity)
        ]

        primary_args_string = _format_args(primary_args)
        secondary_args_string = _format_args(secondary_args)
        bound_list = "[" + ", ".join(bound_names) + "]"

        code = f"""
        params {num_free};
        bounded params {len(bound_names)};
        """

        code += f"""
        p_0 := Pred(
            {primary_program.dsl()}
        );
        p_1 := Pred(
            {secondary_program.dsl()}
        );
        """

        code += f"""
        ReturnExpr None;
        ReturnPred ForAll({bound_list}, Implies(p_0({primary_args_string}), p_1({secondary_args_string})));
        """

        template = Z3Template(code)
        template.set_args(*args)
        return template


    def _build_single_forall_expr(
        self,
        concept: Concept,
        args: Tuple[Any, ...],
        indices: List[int],
        kept_indices: List[int],
        # domains: List[Entity],
    ) -> Expression:
        """Build ``Forall x_i ... . concept(...)`` for the single-predicate case.

        Positions listed in ``indices`` are filled with variables named
        ``x{position}``. Every other position takes its concrete value from
        ``args``, where ``args[k]`` supplies position ``kept_indices[k]``.
        One ``Forall`` over ``NatDomain()`` is added per quantified position,
        with ``indices[0]`` outermost.
        """
        # Concrete value for each kept (unquantified) argument position.
        value_at = {position: args[slot] for slot, position in enumerate(kept_indices)}

        concept_args = [
            Var(f"x{position}") if position in indices else value_at[position]
            for position in range(concept.get_input_arity())
        ]
        expr = ConceptApplication(concept, *concept_args)

        # Wrap innermost-first so the first quantified index ends up outermost.
        for position in reversed(indices):
            expr = Forall(f"x{position}", NatDomain(), expr)
        return expr

    def _build_symbolic_definition_two_pred(
        self,
        primary: Concept,
        secondary: Concept,
        var_info: Dict[str, Dict[str, Any]],
        quantified_var_ids: set[str],
        parameter_var_ids: List[str],
        is_conjecture: bool
    ) -> callable:
        """Create the symbolic definition for ``Forall q. P(...) => Q(...)``.

        ``build_expression(*param_values)`` binds the k-th value to
        ``parameter_var_ids[k]``. Every other variable in ``var_info`` is
        represented by its ``var_obj``. Each variable is placed at its
        ``primary_idx`` in P's argument list and/or at its ``secondary_idx`` in
        Q's. The implication is then wrapped in one ``Forall`` over
        ``NatDomain()`` per quantified variable. Quantified ids are nested in
        ascending (string) sort order, from outermost to innermost.

        Returns:
            For a conjecture, a zero-argument callable returning an expression
            built once up front. Otherwise, ``build_expression`` itself.
        """

        def build_expression(*param_values):
            p_args = [None] * primary.get_input_arity()
            q_args = [None] * secondary.get_input_arity()

            # Parameters receive the caller's values, positionally.
            bindings = {vid: param_values[k] for k, vid in enumerate(parameter_var_ids)}

            for vid, meta in var_info.items():
                term = bindings[vid] if vid in bindings else meta["var_obj"]
                if "primary_idx" in meta:
                    p_args[meta["primary_idx"]] = term
                if "secondary_idx" in meta:
                    q_args[meta["secondary_idx"]] = term

            body = Implies(
                ConceptApplication(primary, *p_args),
                ConceptApplication(secondary, *q_args),
            )

            # Largest id first, so the smallest id ends up as the outermost quantifier.
            for vid in sorted(quantified_var_ids, reverse=True):
                body = Forall(var_info[vid]["var_obj"].name, NatDomain(), body)
            return body

        if not is_conjecture:
            return build_expression

        # A conjecture has no parameters, so its expression can be fixed now.
        frozen_expr = build_expression()
        return lambda: frozen_expr

    def _build_computational_impl_two_pred(
        self,
        primary: Concept,
        secondary: Concept,
        var_info: Dict[str, Dict[str, Any]],
        quantified_var_ids: set[str],
        parameter_var_ids: List[str]
    ) -> callable:
        """Build the computational implementation for ``Forall q. P(...) => Q(...)``.

        The returned ``impl(*param_values)`` receives the new concept's
        arguments in ``parameter_var_ids`` order. It searches for a
        counterexample: values of the quantified variables, drawn from
        ``0..MAX_VALUE_PER_VAR``, for which ``primary`` computes truthy and
        ``secondary`` computes falsy.

        Search strategy:
          * If the bounded space ``(MAX_VALUE_PER_VAR + 1) ** k`` (``k`` =
            number of quantified variables) has at most ``MAX_SEARCH_LIMIT``
            points, every point is checked. The answer is then deterministic
            and exact on that bounded range; with ``k == 0`` it is one check.
          * Otherwise, ``MAX_SEARCH_LIMIT`` random points are checked.
          * Either way, the loop stops after ``self.internal_search_timeout``
            seconds.

        Candidates whose argument lists still contain ``None`` are skipped, as
        are candidates on which either predicate raises.

        Args:
            primary: Antecedent predicate P.
            secondary: Consequent predicate Q.
            var_info: Per-variable metadata. ``role`` marks parameters and
                quantified variables; ``primary_idx`` and ``secondary_idx``
                give their argument positions in P and Q.
            quantified_var_ids: Ids of the universally quantified variables.
            parameter_var_ids: Ids of the new concept's arguments, in order.

        Returns:
            ``impl``, which returns False when a counterexample is found and
            True otherwise. True is optimistic: a bounded search cannot prove
            the universal statement.
        """
        primary_arity = primary.get_input_arity()
        secondary_arity = secondary.get_input_arity()

        def _slots(var_id, role):
            """Return the in-range (P positions, Q positions) filled by a variable of `role`."""
            info = var_info.get(var_id, {})
            if info.get("role") != role:
                return (), ()
            p_slots = ()
            if "primary_idx" in info and info["primary_idx"] < primary_arity:
                p_slots = (info["primary_idx"],)
            q_slots = ()
            if "secondary_idx" in info and info["secondary_idx"] < secondary_arity:
                q_slots = (info["secondary_idx"],)
            return p_slots, q_slots

        # Everything below is independent of the call arguments; compute it once.
        param_slots = [_slots(var_id, "parameter") for var_id in parameter_var_ids]
        quantified_ids = sorted(quantified_var_ids)
        quant_slots = [_slots(var_id, "quantified") for var_id in quantified_ids]
        num_quantified = len(quantified_ids)
        value_range = range(MAX_VALUE_PER_VAR + 1)
        exhaustive = len(value_range) ** num_quantified <= MAX_SEARCH_LIMIT

        def _candidate_assignments():
            """Yield value tuples for the quantified variables (quantified_ids order)."""
            if exhaustive:
                yield from itertools.product(value_range, repeat=num_quantified)
            else:
                for _ in range(MAX_SEARCH_LIMIT):
                    yield tuple(
                        random.randint(0, MAX_VALUE_PER_VAR) for _ in range(num_quantified)
                    )

        def impl(*param_values):
            # Parameter positions are the same for every candidate; fill them once.
            base_primary = [None] * primary_arity
            base_secondary = [None] * secondary_arity
            for k, (p_slots, q_slots) in enumerate(param_slots):
                value = param_values[k]
                for idx in p_slots:
                    base_primary[idx] = value
                for idx in q_slots:
                    base_secondary[idx] = value

            timeout = self.internal_search_timeout
            start_time = time.time()
            for assignment in _candidate_assignments():
                if time.time() - start_time > timeout:
                    logger.warning(f"ForallRule computational_impl (two predicate) internal search timed out after {self.internal_search_timeout}s for {primary.name}=>{secondary.name}")
                    break  # No counterexample found in time: report True below.

                primary_args = list(base_primary)
                secondary_args = list(base_secondary)
                for value, (p_slots, q_slots) in zip(assignment, quant_slots):
                    for idx in p_slots:
                        primary_args[idx] = value
                    for idx in q_slots:
                        secondary_args[idx] = value

                try:
                    if None in primary_args or None in secondary_args:
                        continue  # Some position has no value mapped to it.
                    # Counterexample: P holds but Q does not.
                    if primary.compute(*primary_args) and not secondary.compute(*secondary_args):
                        return False
                except Exception as exc:
                    # Best-effort search: predicates may reject some inputs.
                    logger.debug(
                        "Skipping candidate P%s / Q%s for %s=>%s: %s",
                        primary_args, secondary_args, primary.name, secondary.name, exc,
                    )

            return True

        return impl

    def _transform_examples(
        self,
        new_concept: Entity,
        primary: Concept,
        secondary: Optional[Concept] = None,
        # Args for single-predicate case (legacy-like)
        indices_to_quantify: Optional[List[int]] = None,
        kept_indices: Optional[List[int]] = None,
        # Args for two-predicate case (new structure)
        var_info: Optional[Dict[str, Dict[str, Any]]] = None,
        parameter_var_ids: Optional[List[str]] = None,
        quantified_var_ids: Optional[set[str]] = None
    ):
        """
        Use examples/non-examples from input concepts to find non-examples
        for the newly created universally quantified concept.

        - For ∀x.P(x, params): Finds params where P(x, params) is known false for some x.
        - For ∀q. (P(q, params) => Q(q, params)): Finds params where P(q, params) is known true
          and Q(q, params) is known false for the same q.
        """
        if self.verbose:
            print(f"\nAttempting to transform examples/non-examples for {new_concept.name}")

        # --- Single Predicate Case: ∀x. P(x, params) --- 
        if secondary is None:
            assert kept_indices is not None
            assert indices_to_quantify is not None
            if not primary.examples: # No examples to process
                return

            # Derive non-examples for the new concept:
            # If P(q, p) is false for *any* q (given p),
            # then ∀q.P(q, p) is false for that p.
            # So, extract the parameters (kept_indices) from each primary non-example.
            concept_arity = primary.get_input_arity()
            added_non_examples = set() # Track added non-examples to avoid duplicates

            for ex in primary.examples.get_nonexamples():
                if not isinstance(ex.value, tuple) or len(ex.value) != concept_arity:
                    if self.verbose: print(f"Skipping malformed non-example: {ex.value}")
                    continue # Skip malformed examples

                # Parameters are the values at the kept indices
                params_value = tuple(ex.value[i] for i in kept_indices)

                # Add the parameter tuple as a non-example if not already added
                if params_value not in added_non_examples:
                    if self.verbose:
                        print(f"  Adding non-example (derived from P non-example {ex.value}): {params_value}")
                    new_concept.add_nonexample(params_value)
                    added_non_examples.add(params_value)

        # --- Two Predicate Case: ∀q. (P(...) ⇒ Q(...)) ---
        else:
            assert var_info is not None
            assert parameter_var_ids is not None
            assert quantified_var_ids is not None
            if not primary.examples or not hasattr(secondary.examples, 'get_nonexamples') or not secondary.examples.get_nonexamples():
                 if self.verbose:
                     print("  Skipping two-predicate transform: Missing primary examples or secondary non-examples.")
                 return
            if not secondary.can_add_nonexamples: # Check capability flag
                 if self.verbose:
                     print("  Skipping two-predicate transform: Secondary concept cannot reliably provide non-examples.")
                 return

            # 1. Determine how to extract params and quantified parts from P and Q examples/non-examples
            # This maps parameter_var_ids and quantified_var_ids back to original P/Q indices
            param_indices_p = []
            quant_indices_p = []
            param_indices_q = []
            quant_indices_q = [] 
            ordered_quant_ids = sorted(list(quantified_var_ids))
            
            p_arity = primary.get_input_arity()
            q_arity = secondary.get_input_arity()
            p_reconstruction_map = {i: None for i in range(p_arity)}
            q_reconstruction_map = {i: None for i in range(q_arity)}

            # Map parameters to their original indices
            for param_idx, var_id in enumerate(parameter_var_ids):
                info = var_info[var_id]
                if "primary_idx" in info:
                    p_reconstruction_map[info["primary_idx"]] = ("param", param_idx)
                if "secondary_idx" in info:
                    q_reconstruction_map[info["secondary_idx"]] = ("param", param_idx)
            
            # Map quantified variables to their original indices
            for quant_idx, var_id in enumerate(ordered_quant_ids):
                info = var_info[var_id]
                if "primary_idx" in info:
                    p_reconstruction_map[info["primary_idx"]] = ("quant", quant_idx)
                if "secondary_idx" in info:
                    q_reconstruction_map[info["secondary_idx"]] = ("quant", quant_idx)

            # Helper to extract parts from a full P example
            def extract_p(full_p_value):
                params = [None] * len(parameter_var_ids)
                quants = [None] * len(ordered_quant_ids)
                for p_idx, mapping in p_reconstruction_map.items():
                    if mapping is None: continue
                    type, target_idx = mapping
                    if type == "param":
                        params[target_idx] = full_p_value[p_idx]
                    else: # quant
                        quants[target_idx] = full_p_value[p_idx]
                return tuple(params), tuple(quants)

            # Helper to extract parts from a full Q non-example
            def extract_q(full_q_value):
                params = [None] * len(parameter_var_ids)
                quants = [None] * len(ordered_quant_ids)
                for q_idx, mapping in q_reconstruction_map.items():
                    if mapping is None: continue
                    type, target_idx = mapping
                    if type == "param":
                        params[target_idx] = full_q_value[q_idx]
                    else: # quant
                        quants[target_idx] = full_q_value[q_idx]
                return tuple(params), tuple(quants)

            # --- Refactored Logic: Direct Comparison ---
            p_examples = list(primary.examples.get_examples())
            q_non_examples = list(secondary.examples.get_nonexamples())
            non_example_params = set()
            
            p_arity = primary.get_input_arity()
            q_arity = secondary.get_input_arity()

            if self.verbose:
                print(f"Starting direct comparison: {len(p_examples)} P examples vs {len(q_non_examples)} Q non-examples.")

            for ex_p in p_examples:
                if not isinstance(ex_p.value, tuple) or len(ex_p.value) != p_arity:
                    if self.verbose: print(f"Skipping malformed P example: {ex_p.value}")
                    continue # Skip malformed P example

                for nex_q in q_non_examples:
                    if not isinstance(nex_q.value, tuple) or len(nex_q.value) != q_arity:
                        if self.verbose: print(f"Skipping malformed Q non-example: {nex_q.value}")
                        continue # Skip malformed Q non-example

                    # Check if values at shared/mapped indices match
                    match = True
                    for var_id, info in var_info.items():
                        if info["origin"] == "shared":
                            p_idx = info["primary_idx"]
                            q_idx = info["secondary_idx"]
                            # Check bounds before accessing values
                            if p_idx >= len(ex_p.value) or q_idx >= len(nex_q.value):
                                 if self.verbose: print(f"WARN: Index out of bounds during shared var check. P_idx={p_idx}, Q_idx={q_idx}. P_val={ex_p.value}, Q_val={nex_q.value}")
                                 match = False
                                 break # Cannot compare if indices invalid for the data
                            if ex_p.value[p_idx] != nex_q.value[q_idx]:
                                match = False
                                break # Mismatch found

                    # If all shared variables match, extract parameters and add non-example
                    if match:
                        current_params = [None] * len(parameter_var_ids)
                        valid_params = True
                        for i, param_var_id in enumerate(parameter_var_ids):
                            info = var_info[param_var_id]
                            # Parameters must originate from either P or Q
                            if "primary_idx" in info:
                                p_idx = info["primary_idx"]
                                if p_idx < len(ex_p.value):
                                     current_params[i] = ex_p.value[p_idx]
                                else:
                                     if self.verbose: print(f"WARN: Param {param_var_id} primary_idx {p_idx} out of bounds for P example {ex_p.value}")
                                     valid_params = False; break
                            elif "secondary_idx" in info:
                                q_idx = info["secondary_idx"]
                                if q_idx < len(nex_q.value):
                                    current_params[i] = nex_q.value[q_idx]
                                else:
                                     if self.verbose: print(f"WARN: Param {param_var_id} secondary_idx {q_idx} out of bounds for Q non-example {nex_q.value}")
                                     valid_params = False; break
                            else:
                                 if self.verbose: print(f"WARN: Parameter variable {param_var_id} missing P/Q index in var_info.")
                                 valid_params = False; break

                        # Add the extracted parameter tuple if valid and complete
                        if valid_params and None not in current_params:
                            params_tuple = tuple(current_params)
                            if params_tuple not in non_example_params: # Avoid verbose print for duplicates
                                if self.verbose:
                                     print(f"  Found counterexample pair: P={ex_p.value}, Q_non={nex_q.value}. Adding non-example params: {params_tuple}")
                                non_example_params.add(params_tuple)
                                # Optional: break # from inner nex_q loop if only one counterexample needed per ex_p

            # Add all unique found non-examples to the new concept
            if self.verbose:
                print(f"Adding {len(non_example_params)} unique non-examples derived from P/Q pairs.")
            for params_tuple in non_example_params:
                new_concept.add_nonexample(params_tuple)

            # --- End of Refactored Logic ---

            # Remove old aggregation logic and the loop that used it
            # Commented out instead of deleted for reference:
            # # 2. Aggregate P examples (True P) grouped by parameters
            # p_true_qvals_by_params = {}
            # ... (old code for p_true_qvals_by_params) ...

            # # 3. Aggregate Q non-examples (False Q) grouped by parameters
            # q_false_qvals_by_params = {}
            # ... (old code for q_false_qvals_by_params) ...

            # # 4. Find parameter sets where P is true and Q is false for the *same* quantified values
            # for params, p_true_qvals in p_true_qvals_by_params.items():
            # ... (old code for comparison loop) ...

    def _build_computational_impl(
        self,
        concept: Concept,
        indices: List[int],
        kept_indices: List[int],
    ) -> callable:
        """Build computational implementation for the single-predicate concept."""
        num_vars = len(indices)

        def impl(*args):
            # Prepare arguments for compute call, with placeholders for quantified vars
            full_args = [None] * concept.examples.example_structure.input_arity

            # Populate unquantified (kept) arguments
            for i, kept_idx in enumerate(kept_indices):
                full_args[kept_idx] = args[i]

            # Use random sampling directly
            # ASSUMPTION: Quantified variables are numeric/sampleable in this range.
            for _ in range(MAX_SEARCH_LIMIT):
                try:
                    # Generate a random combination of values for quantified variables
                    values = tuple(
                        random.randint(0, MAX_VALUE_PER_VAR) for _ in range(num_vars)
                    )
                    # Assign values to quantified variables
                    for idx_pos, idx in enumerate(indices):
                        full_args[idx] = values[idx_pos]
                except TypeError:
                    # Cannot sample non-numeric types this way
                    # TODO: Implement proper sampling based on domain type
                    print(f"Warning: Cannot perform random sampling for non-numeric quantified variables in {concept.name}. Returning True (optimistic).")
                    return True # Cannot falsify, return True

                # Check if this combination satisfies the predicate
                try:
                    if not concept.compute(*full_args):
                        # Found a counterexample
                        return False
                except Exception as e:
                    # Skip combinations that cause errors
                    if self.verbose:
                         print(f"Skipping combination due to error: {e}")
                    continue

            # No counterexample found among the sampled values
            return True

        return impl