"""
This module implements the Compose production rule for composing concepts.

The rule supports three types of composition:

1. Function Composition (g ∘ f):
   For functions f: X₁ × ... × Xₙ → Y₁ × ... × Yₘ and g: Z₁ × ... × Zₖ → W
   Given:
   - output_to_input_map: Dict[int, int] mapping f's output indices to g's input indices
   Creates h where:
   h(x₁,...,xₙ,p₁,...,pᵢ) = g(v₁,...,vₖ)
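   where i = k − |output_to_input_map| is the number of unmapped inputs of g, and:
   - vⱼ = f(x₁,...,xₙ)ₛ if output_to_input_map maps output s of f to input j of g
   - otherwise vⱼ is the next unused parameter p, taken in increasing order of j
     (every input of g that is not in output_to_input_map.values() becomes a parameter)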

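2. Predicate Composition:
   For predicates P: X₁ × ... × Xₙ → Bool and Q: Z₁ × ... × Zₖ → Bool
   Given:
   - shared_vars: Dict[int, int] mapping P's input indices to Q's input indices
   Creates R where:
   R(x₁,...,xₙ,p₁,...,pᵢ) = P(x₁,...,xₙ) ∧ Q(v₁,...,vₖ)
   where i = k − |shared_vars| and:
   - vⱼ = xₛ if shared_vars maps input s of P to input j of Q
   - otherwise vⱼ is the next unused parameter p, taken in increasing order of j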

3. Function to Predicate Composition:
   For functions f: X₁ × ... × Xₙ → Y₁ × ... × Yₘ and P: Z₁ × ... × Zₖ → Bool
   Given:
   - output_to_input_map: Dict[int, int] mapping f's output indices to P's input indices
   Creates h where:
   h(x₁,...,xₙ,p₁,...,pᵢ) = P(v₁,...,vₖ)
   where vⱼ values follow same rules as function composition

The mapping must satisfy:
- Each input index of g (or P) appears at most once among output_to_input_map values
- Inputs of g (or P) that do not appear there become parameters of the new concept
"""
import random
from typing import List, Optional, Tuple, Any, Dict, Type, Union
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    And,
    Entity,
    Concept,
    ExampleStructure,
    ConceptType,
)

from frame.tools.z3_template import Z3Template, _format_args

def validate_mapping(
    output_to_input_map: Dict[int, int],
    inner_concept: Concept,
    outer_concept: Concept
) -> bool:
    """
    Validate that a mapping for composition is valid.

    A mapping is valid when:
    - every key is an existing output index of ``inner_concept``
      (``0 <= key < inner output arity``),
    - every value is an existing input index of ``outer_concept``
      (``0 <= value < outer input arity``),
    - no input of ``outer_concept`` receives more than one inner output, and
    - each mapped output has the same component type as the input it feeds.

    Inputs of ``outer_concept`` that are not mapped become parameters of the
    composed concept, so leaving inputs unmapped is always allowed.

    Args:
        output_to_input_map: Maps output indices of inner_concept to input indices of outer_concept
        inner_concept: The inner concept in the composition
        outer_concept: The outer concept in the composition

    Returns:
        bool: True if the mapping is valid, False otherwise
    """
    inner_in_arity = inner_concept.get_input_arity()
    inner_types = inner_concept.get_component_types()
    inner_out_arity = len(inner_types) - inner_in_arity
    outer_in_arity = outer_concept.get_input_arity()
    outer_types = outer_concept.get_component_types()

    mapped_inputs = list(output_to_input_map.values())

    # Output indices must exist. Negative indices are rejected explicitly
    # because Python indexing would silently wrap them around.
    if any(not 0 <= i < inner_out_arity for i in output_to_input_map):
        return False

    # Input indices must exist (same reasoning for negative values).
    if any(not 0 <= i < outer_in_arity for i in mapped_inputs):
        return False

    # Each outer input may receive at most one inner output; a duplicate would
    # silently overwrite an earlier value when the outer arguments are built.
    if len(set(mapped_inputs)) != len(mapped_inputs):
        return False

    # Check type compatibility for mapped outputs to inputs
    for out_idx, in_idx in output_to_input_map.items():
        inner_out_type = inner_types[inner_in_arity + out_idx]
        outer_in_type = outer_types[in_idx]
        if inner_out_type != outer_in_type:
            return False

    return True


class ComposeRule(ProductionRule):
    """
    Production rule that composes two concepts to create a new concept.

    Supports three types of composition:
    1. Function-function: g(f(...))
    2. Predicate-predicate: P(x) ∧ Q(y) where some variables are shared
    3. Function-predicate: P(f(x))
    """

    def __init__(self, verbose=False):
        """
        Register this rule with the base production-rule machinery.

        The rule is registered under the name ``"compose"`` and is declared to
        produce ``"Concept"`` entities.

        Args:
            verbose: Accepted for interface compatibility; this initializer
                does not currently use it.
        """
        rule_metadata = dict(
            name="compose",
            description="Composes two concepts to create a new concept",
            type="Concept",
        )
        super().__init__(**rule_metadata)

    def determine_verification_capabilities(self, *inputs: Entity) -> Tuple[bool, bool]:
        """
        Derive the composed concept's example / non-example capabilities.

        The flags of the two input concepts are combined as follows:

        - predicate ∧ predicate: examples require both predicates (``and``);
          non-examples require only one of them (``or``), because a single
          false conjunct already makes the conjunction false.
        - function ∘ function and function → predicate: both flags require
          both inputs (``and``).

        If the call does not receive exactly two inputs, the permissive
        default ``(True, True)`` is returned.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples)
        """
        if len(inputs) != 2:
            return True, True

        first, second = inputs
        first_type = first.examples.example_structure.concept_type
        second_type = second.examples.example_structure.concept_type

        # Read all four flags before combining them.
        first_examples, first_nonexamples = first.can_add_examples, first.can_add_nonexamples
        second_examples, second_nonexamples = second.can_add_examples, second.can_add_nonexamples

        examples = first_examples and second_examples
        if first_type == second_type == ConceptType.PREDICATE:
            nonexamples = first_nonexamples or second_nonexamples
        else:
            nonexamples = first_nonexamples and second_nonexamples
        return examples, nonexamples

    def get_input_types(
        self,
    ) -> List[
        List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]
    ]:
        """
        Describe the accepted input signatures.

        The rule takes exactly two concepts in one of three combinations:
        (function, function), (predicate, predicate) or (function, predicate).

        Returns:
            List[List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]]:
                One specification per accepted combination.
        """
        function_slot = (Concept, ConceptType.FUNCTION)
        predicate_slot = (Concept, ConceptType.PREDICATE)
        accepted_pairs = (
            (function_slot, function_slot),
            (predicate_slot, predicate_slot),
            (function_slot, predicate_slot),
        )
        return [list(pair) for pair in accepted_pairs]

    def _validate_function_composition(
        self,
        inner: Concept,
        outer: Concept,
        output_to_input_map: Dict[int, int],
        verbose: bool = False,
    ) -> bool:
        """
        Validate function composition.

        The map must be non-empty (a missing or empty map would not connect
        the two functions) and must pass ``validate_mapping``.

        Args:
            inner: Inner function in the composition
            outer: Outer function in the composition
            output_to_input_map: Maps output indices of inner to input indices of outer
            verbose: Whether to print debug information

        Returns:
            bool: True if the composition is valid, False otherwise
        """
        # Check that we have a valid mapping
        if not output_to_input_map:
            if verbose:
                print("❌ Failed: Function composition requires output_to_input_map")
            return False

        # Validate the mapping
        if not validate_mapping(output_to_input_map, inner, outer):
            if verbose:
                print("❌ Failed: Invalid mapping for function composition")
            return False

        if verbose:
            print(f"✓ Valid function composition with mapping: {output_to_input_map}")
        return True

    def _validate_predicate_composition(
        self,
        pred1: Concept,
        pred2: Concept,
        shared_vars: Dict[int, int],
        verbose: bool = False,
    ) -> bool:
        """
        Validate predicate composition.

        ``shared_vars`` must be non-empty. Every key must be an input index of
        ``pred1`` and every value an input index of ``pred2``
        (``0 <= index < arity``). No input of ``pred2`` may be shared with more
        than one input of ``pred1``.

        Args:
            pred1: First predicate in the composition
            pred2: Second predicate in the composition
            shared_vars: Maps input indices of pred1 to input indices of pred2
            verbose: Whether to print debug information

        Returns:
            bool: True if the composition is valid, False otherwise
        """
        if not shared_vars:
            if verbose:
                print("❌ Failed: Predicate composition requires shared_vars")
            return False

        arity1 = pred1.get_input_arity()
        arity2 = pred2.get_input_arity()

        # Indices must address existing inputs. Negative values are rejected
        # explicitly because Python indexing would silently wrap them around.
        if any(not 0 <= i < arity1 for i in shared_vars) or any(
            not 0 <= i < arity2 for i in shared_vars.values()
        ):
            if verbose:
                print("❌ Failed: Invalid shared variable indices")
            return False

        # Sharing one pred2 input with two pred1 inputs would keep only the last
        # assignment when the composed arguments are built. The two pred1 inputs
        # would not be forced to be equal, so reject it.
        targets = list(shared_vars.values())
        if len(set(targets)) != len(targets):
            if verbose:
                print(
                    "❌ Failed: Shared variable indices of the second predicate must be distinct"
                )
            return False

        if verbose:
            print(f"✓ Valid predicate composition with shared vars: {shared_vars}")
        return True

    def _validate_function_to_predicate(
        self,
        func: Concept,
        pred: Concept,
        output_to_input_map: Dict[int, int],
        verbose: bool = False,
    ) -> bool:
        """
        Check a function → predicate composition request.

        The request is rejected if ``output_to_input_map`` is missing or empty,
        or if ``validate_mapping`` refuses it.

        Args:
            func: Function whose outputs feed the predicate
            pred: Predicate receiving the mapped outputs
            output_to_input_map: Maps output indices of ``func`` to input indices of ``pred``
            verbose: Whether to print debug information

        Returns:
            bool: True if the composition is valid, False otherwise
        """
        failure = None
        if not output_to_input_map:
            failure = "❌ Failed: Function-to-predicate composition requires output_to_input_map"
        elif not validate_mapping(output_to_input_map, func, pred):
            failure = "❌ Failed: Invalid mapping for function-to-predicate composition"

        if failure is not None:
            if verbose:
                print(failure)
            return False

        if verbose:
            print(
                f"✓ Valid function-to-predicate composition with mapping: {output_to_input_map}"
            )
        return True

    def get_valid_parameterizations(self, *inputs: Entity) -> List[Dict[str, Any]]:
        """
        Get valid parameterizations for composing the given concepts.

        Each returned dictionary is keyword arguments for ``can_apply`` /
        ``apply``:

        - function ∘ function: ``{"output_to_input_map": ...}``. Every output
          of the first concept is mapped, each to a distinct, type-compatible
          input of the second.
        - function → predicate: the same shape, but outputs may be left
          unmapped. At least one output must be mapped.
        - predicate ∧ predicate: ``{"shared_vars": ...}``. Inputs of the first
          predicate are paired with distinct inputs of the second, with at
          least one pair.

        Output-map parameterizations are ordered by the number of unmapped
        inputs of the second concept, fewest first. Shared-variable
        parameterizations are ordered by the number of shared variables, most
        first. Ties keep enumeration order.

        Args:
            *inputs: Two concepts to compose

        Returns:
            List of valid parameter dictionaries (empty when the inputs do not
            match ``get_input_types``).
        """
        # Check if the inputs match the expected input types
        if not self.check_input_types(*inputs):
            return []

        concept1, concept2 = inputs
        type1 = concept1.examples.example_structure.concept_type
        type2 = concept2.examples.example_structure.concept_type

        if type1 == type2 == ConceptType.PREDICATE:
            shared_maps = self._enumerate_shared_vars(
                concept1.get_input_arity(), concept2.get_input_arity()
            )
            # Sharing nothing is not a composition
            parameterizations = [{"shared_vars": m} for m in shared_maps if m]
            parameterizations.sort(key=lambda p: len(p["shared_vars"]), reverse=True)
            return parameterizations

        if type1 == type2 == ConceptType.FUNCTION:
            allow_unmapped = False
        elif type1 == ConceptType.FUNCTION and type2 == ConceptType.PREDICATE:
            allow_unmapped = True
        else:
            return []

        output_maps = self._enumerate_output_maps(concept1, concept2, allow_unmapped)
        parameterizations = [
            {"output_to_input_map": m}
            for m in output_maps
            # can_apply rejects an empty map, so never offer one
            if m and validate_mapping(m, concept1, concept2)
        ]

        outer_in_arity = concept2.examples.example_structure.input_arity
        # Fewer unmapped outer inputs means fewer new parameters: simpler first
        parameterizations.sort(
            key=lambda p: sum(
                1
                for i in range(outer_in_arity)
                if i not in p["output_to_input_map"].values()
            )
        )
        return parameterizations

    @staticmethod
    def _enumerate_output_maps(
        inner: Concept, outer: Concept, allow_unmapped: bool
    ) -> List[Dict[int, int]]:
        """
        Enumerate injective, type-compatible maps from inner outputs to outer inputs.

        Args:
            inner: Concept whose outputs are fed forward.
            outer: Concept whose inputs receive those outputs.
            allow_unmapped: If False, every output of ``inner`` must be mapped.
                If True, outputs may be left unmapped.

        Returns:
            The maps in depth-first order. For each output, "leave unmapped"
            (when allowed) is explored first, then each compatible input in
            ascending order.
        """
        inner_in_arity = inner.get_input_arity()
        inner_types = inner.get_component_types()
        inner_out_arity = len(inner_types) - inner_in_arity
        outer_in_arity = outer.get_input_arity()
        outer_types = outer.get_component_types()

        # compatible[out_idx] lists the outer inputs whose type matches that output
        compatible = [
            [
                in_idx
                for in_idx in range(outer_in_arity)
                if inner_types[inner_in_arity + out_idx] == outer_types[in_idx]
            ]
            for out_idx in range(inner_out_arity)
        ]

        results: List[Dict[int, int]] = []

        def _walk(out_idx: int, current: Dict[int, int], used: frozenset) -> None:
            if out_idx == inner_out_arity:
                results.append(dict(current))
                return
            if allow_unmapped:
                _walk(out_idx + 1, current, used)
            for in_idx in compatible[out_idx]:
                # Each outer input receives at most one output
                if in_idx not in used:
                    _walk(out_idx + 1, {**current, out_idx: in_idx}, used | {in_idx})

        _walk(0, {}, frozenset())
        return results

    @staticmethod
    def _enumerate_shared_vars(arity1: int, arity2: int) -> List[Dict[int, int]]:
        """
        Enumerate injective partial maps from pred1 inputs to pred2 inputs.

        Args:
            arity1: Input arity of the first predicate.
            arity2: Input arity of the second predicate.

        Returns:
            All maps (including the empty one) in depth-first order. For each
            pred1 input, "not shared" is explored first, then each pred2 input
            in ascending order.
        """
        results: List[Dict[int, int]] = []

        def _walk(var1_idx: int, current: Dict[int, int], used: frozenset) -> None:
            if var1_idx == arity1:
                results.append(dict(current))
                return
            _walk(var1_idx + 1, current, used)
            for var2_idx in range(arity2):
                if var2_idx not in used:
                    _walk(
                        var1_idx + 1, {**current, var1_idx: var2_idx}, used | {var2_idx}
                    )

        _walk(0, {}, frozenset())
        return results

    def can_apply(
        self,
        *inputs: Entity,
        output_to_input_map: Optional[Dict[int, int]] = None,
        shared_vars: Optional[Dict[int, int]] = None,
        verbose: bool = True,
    ) -> bool:
        """
        Report whether the two inputs can be composed with the given mapping.

        Function-function and function-predicate composition are driven by
        ``output_to_input_map`` and refuse a ``shared_vars`` argument.
        Predicate-predicate composition is driven by ``shared_vars`` and
        refuses an ``output_to_input_map`` argument. The mapping itself is
        checked by the matching ``_validate_*`` helper.

        Args:
            *inputs: Two concepts to compose
            output_to_input_map: For function composition, maps outputs to inputs
            shared_vars: For predicate composition, maps variables to be shared
            verbose: Whether to print debug information

        Returns:
            bool: True if the composition is valid for these inputs.
        """
        def fail(message):
            # Single exit path for every rejection; printing is opt-in.
            if verbose:
                print(message)
            return False

        if len(inputs) != 2 or any(not isinstance(item, Concept) for item in inputs):
            return fail("❌ Failed: Must have exactly two inputs of type Concept")

        first, second = inputs
        first_type = first.examples.example_structure.concept_type
        second_type = second.examples.example_structure.concept_type

        if first_type == second_type == ConceptType.FUNCTION:
            if shared_vars is not None:
                return fail("❌ Failed: Function composition requires output_to_input_map")
            return self._validate_function_composition(
                first, second, output_to_input_map, verbose
            )

        if first_type == second_type == ConceptType.PREDICATE:
            if output_to_input_map is not None:
                return fail("❌ Failed: Predicate composition requires shared_vars")
            return self._validate_predicate_composition(
                first, second, shared_vars, verbose
            )

        if first_type == ConceptType.FUNCTION and second_type == ConceptType.PREDICATE:
            if shared_vars is not None:
                return fail(
                    "❌ Failed: Function-to-predicate composition requires output_to_input_map"
                )
            return self._validate_function_to_predicate(
                first, second, output_to_input_map, verbose
            )

        return fail("❌ Failed: Invalid combination of concept types")

    def apply(
        self,
        *inputs: Entity,
        output_to_input_map: Dict[int, int] = None,
        shared_vars: Dict[int, int] = None,
    ) -> Entity:
        """
        Build the composed concept.

        The inputs are first checked with ``can_apply``. The matching
        ``_compose_*`` helper then builds the concept. The result receives the
        larger ``map_iterate_depth`` of its inputs and the capability flags
        from ``determine_verification_capabilities``.

        Args:
            *inputs: Two concepts to compose
            output_to_input_map: For function and function-to-predicate
                composition, maps outputs to inputs
            shared_vars: For predicate composition, maps variables to be shared

        Returns:
            Entity: The newly composed concept

        Raises:
            ValueError: If the inputs cannot be composed with the given mapping.
        """
        applicable = self.can_apply(
            *inputs, output_to_input_map=output_to_input_map, shared_vars=shared_vars
        )
        if not applicable:
            raise ValueError("Cannot apply Composition to these inputs")

        first, second = inputs
        type1 = first.examples.example_structure.concept_type
        type2 = second.examples.example_structure.concept_type

        examples_flag, nonexamples_flag = self.determine_verification_capabilities(*inputs)

        if type1 == type2 == ConceptType.FUNCTION:
            composed = self._compose_functions(first, second, output_to_input_map)
        elif type1 == type2 == ConceptType.PREDICATE:
            composed = self._compose_predicates(first, second, shared_vars)
        elif type1 == ConceptType.FUNCTION and type2 == ConceptType.PREDICATE:
            composed = self._compose_function_to_predicate(
                first, second, output_to_input_map
            )
        else:
            # Unreachable after can_apply, kept as a defensive guard
            raise ValueError(f"Invalid combination of concept types: {type1} and {type2}")

        composed.map_iterate_depth = max(first.map_iterate_depth, second.map_iterate_depth)
        composed.can_add_examples = examples_flag
        composed.can_add_nonexamples = nonexamples_flag
        return composed

    def _compose_functions(
        self, inner: Concept, outer: Concept, output_to_input_map: Dict[int, int]
    ) -> Concept:
        """Implement function composition.

        Builds a FUNCTION concept h. Its arguments are the inner function's
        inputs, followed by one parameter for every outer input that is not a
        value of ``output_to_input_map``, in increasing outer-input order. h
        feeds the mapped inner outputs into the outer function and returns the
        outer function's result.

        Args:
            inner: Function applied first.
            outer: Function applied to the mapped inner outputs plus parameters.
            output_to_input_map: Maps inner output indices to outer input indices.

        Returns:
            Concept: The composed function concept.
            - A computational implementation is attached only when both inputs
              have one.
            - A Z3 translation is attached only when ``inner`` has exactly one
              output and both inputs have Z3 translations.
            Examples are then processed by ``_transform_function_examples``
            (defined outside this method).
        """
        inner_in_arity = inner.get_input_arity()
        inner_out_arity = len(inner.get_component_types()) - inner_in_arity
        outer_in_arity = outer.get_input_arity()
        # NOTE(review): outer_out_arity is computed but not used in this method.
        outer_out_arity = len(outer.get_component_types()) - outer_in_arity

        def _translation(*args, inner_function, outer_function):
            """Apply ``inner_function`` then ``outer_function`` to ``args``.

            ``args`` holds the inner inputs followed by the parameters for the
            unmapped outer inputs.

            Returns None in two cases: too few arguments are supplied, or a
            mapped output index exceeds the length of the inner result.
            """
            # Ensure we have enough arguments
            if len(args) < (inner_in_arity + outer_in_arity - len(output_to_input_map)):
                return None

            # Apply inner function
            inner_result = inner_function(*args[:inner_in_arity])
            # A non-tuple result is treated as a single output so it can be
            # indexed by output index below.
            if not isinstance(inner_result, tuple):
                inner_result = (inner_result,)

            # Build outer function arguments
            outer_args = [None] * outer_in_arity
            for out_idx, in_idx in output_to_input_map.items():
                if out_idx < len(inner_result):
                    outer_args[in_idx] = inner_result[out_idx]
                else:
                    # Output index out of range
                    return None

            # Fill remaining arguments from additional parameters, in
            # increasing outer-input order.
            param_start = inner_in_arity
            for i in range(outer_in_arity):
                if i not in output_to_input_map.values():
                    if param_start < len(args):
                        outer_args[i] = args[param_start]
                        param_start += 1
                    else:
                        # Not enough arguments provided
                        return None

            return outer_function(*outer_args)
        
        def symbolic_definition(*args):
            """Build the symbolic composition from the inputs' ``symbolic`` methods."""
            return _translation(*args, inner_function=inner.symbolic, outer_function=outer.symbolic)

        def compute(*args):
            """Compute the composition; returns None if either compute call fails."""
            try:
                return _translation(*args, inner_function=inner.compute, outer_function=outer.compute)
            except (TypeError, AttributeError, IndexError):
                # TypeError / AttributeError / IndexError raised by either
                # compute (or by argument assembly) is reported as None.
                return None
            
        def _z3_translate_compose_functions(*args):
            """Build a Z3Template for the composition.

            The DSL program defines the inner program as ``f_0`` and the outer
            program as ``f_1``. It returns ``f_1`` applied to a call of ``f_0``
            on the mapped inputs, plus fresh ``x_*`` parameters for the
            unmapped ones. Returns None if ``args`` is too short to supply
            every unmapped outer input.
            """
            inner_template = inner.to_z3(*([None] * inner.get_input_arity()))
            inner_program = inner_template.program

            outer_template = outer.to_z3(*([None] * outer.get_input_arity()))
            outer_program = outer_template.program

            # NOTE(review): `params` is summed directly here but wrapped in
            # int() further below. Confirm its type; if it is a string, this
            # arithmetic would raise.
            code = f"""
            params {inner_program.params + outer_program.params - len(output_to_input_map)};
            bounded params 0;   
            """
            
            code += f"""
            f_0 := Func(
                {inner_program.dsl()}
            );
            f_1 := Func(
                {outer_program.dsl()}
            );
            """ 

            inner_input_arity = int(inner_program.params)
            outer_input_arity = int(outer_program.params)
            
            inner_args_string = f"f_0({_format_args([f'x_{i}' for i in range(inner_input_arity)])})" 
            
            # Build outer function arguments. Every mapped input receives the
            # same f_0(...) expression. That is only correct for a
            # single-output inner function, and the z3_translation argument
            # below attaches this translator only when inner_out_arity == 1.
            outer_args = [None] * outer_input_arity
            for out_idx, in_idx in output_to_input_map.items():
                outer_args[in_idx] = inner_args_string


            # Fill remaining arguments from additional parameters
            param_start = inner_input_arity
            for i in range(outer_input_arity):
                if i not in output_to_input_map.values():
                    if param_start < len(args):
                        outer_args[i] = f"x_{param_start}"
                        param_start += 1
                    else:
                        # Not enough arguments provided
                        return None
            
            outer_args_string = f"f_1({_format_args(outer_args)})"

            code += f"""
            ReturnExpr {outer_args_string};
            ReturnPred None;
            """

            template = Z3Template(code)
            template.set_args(*args)
            return template

        # Calculate new types
        inner_types = inner.examples.example_structure.component_types
        outer_types = outer.examples.example_structure.component_types

        # Input types are inner function inputs plus parameters for unmapped outer inputs
        input_types = list(inner_types[:inner_in_arity])
        for i in range(outer_in_arity):
            if i not in output_to_input_map.values():
                input_types.append(outer_types[i])

        # Output types come from outer function
        output_types = list(outer_types[outer_in_arity:])

        new_concept = Concept(
            name=f"compose_({inner.name}_with_{outer.name}_output_to_input_map={output_to_input_map})",
            description=f"Function composition of {inner.name} and {outer.name}",
            symbolic_definition=symbolic_definition,
            computational_implementation=compute
                if inner.has_computational_implementation()
                and outer.has_computational_implementation()
                else None,
            example_structure=ExampleStructure(
                concept_type=ConceptType.FUNCTION,
                component_types=tuple(input_types + output_types),
                input_arity=len(input_types),
            ),
            # The Z3 translator reuses one f_0(...) call for every mapped
            # input, so it is only valid for single-output inner functions.
            z3_translation=
                _z3_translate_compose_functions
                if inner_out_arity == 1 and inner.has_z3_translation() and outer.has_z3_translation()
                else None,
        )

        # NOTE(review): presumably derives the new concept's examples from
        # those of outer/inner; defined outside this view, so confirm the
        # expected (outer, inner) argument order.
        self._transform_function_examples(
            new_concept, outer, inner, output_to_input_map
        )
        return new_concept

    def _compose_predicates(
        self, pred1: Concept, pred2: Concept, shared_vars: Dict[int, int]
    ) -> Concept:
        """Implement predicate composition (conjunction with shared variables).

        The composed predicate R takes pred1's arguments in their original
        order, followed by pred2's non-shared arguments in pred2's order:

            R(x_1..x_n, p_1..p_i) = pred1(x_1..x_n) AND pred2(v_1..v_k)

        where v_j = x_{idx1} when shared_vars[idx1] == j, otherwise the next p.
        The symbolic definition, the computational implementation, the Z3
        translation and the component types all use this same argument order.

        Args:
            pred1: First predicate of the conjunction.
            pred2: Second predicate of the conjunction.
            shared_vars: Maps pred1 input indices to the pred2 input indices
                that receive the same value.

        Returns:
            The composed predicate Concept. Its examples are transferred by
            ``_transform_predicate_examples``.
        """
        arity1 = pred1.get_input_arity()
        arity2 = pred2.get_input_arity()
        # pred2 positions that are filled from pred1 through shared variables.
        shared_targets = set(shared_vars.values())
        # New arity is sum of arities minus shared variables
        new_arity = arity1 + arity2 - len(shared_vars)

        def _route(values, n1=arity1, n2=arity2):
            """Split a composed argument list into (pred1 args, pred2 args)."""
            args1 = list(values[:n1])
            args2 = [None] * n2
            for idx1, idx2 in shared_vars.items():
                args2[idx2] = args1[idx1]
            # Non-shared pred2 inputs consume the values after pred1's block.
            next_extra = n1
            for i in range(n2):
                if i not in shared_targets:
                    args2[i] = values[next_extra]
                    next_extra += 1
            return args1, args2

        def _evaluate(args, pred1_function, pred2_function):
            """Evaluate both predicates on their routed arguments.

            Raises:
                TypeError: if too few arguments are supplied or a predicate
                    cannot be evaluated. Callers previously saw a TypeError
                    from unpacking None, so the exception type is unchanged.
            """
            if len(args) < new_arity:
                raise TypeError(
                    f"composed predicate expects at least {new_arity} "
                    f"arguments, got {len(args)}"
                )
            args1, args2 = _route(args)
            try:
                return pred1_function(*args1), pred2_function(*args2)
            except AttributeError as err:
                raise TypeError(
                    f"could not evaluate composed predicate: {err}"
                ) from err

        def symbolic_definition(*args):
            """Build conjunction of predicates with shared variables"""
            result1, result2 = _evaluate(args, pred1.symbolic, pred2.symbolic)
            return And(result1, result2)

        def compute(*args):
            """Compute conjunction of predicates"""
            # Both predicates are evaluated eagerly before combining.
            result1, result2 = _evaluate(args, pred1.compute, pred2.compute)
            return result1 and result2

        def _z3_translate_compose_predicates(*args):
            """Build a Z3 template for the conjunction (same argument order as compute)."""
            pred1_template = pred1.to_z3(*([None] * pred1.get_input_arity()))
            pred1_program = pred1_template.program

            pred2_template = pred2.to_z3(*([None] * pred2.get_input_arity()))
            pred2_program = pred2_template.program

            code = f"""
            params {pred1_program.params + pred2_program.params - len(shared_vars)};
            bounded params 0;   
            """

            code += f"""
            p_0 := Pred(
                {pred1_program.dsl()}
            );
            p_1 := Pred(
                {pred2_program.dsl()}
            );
            """

            pred1_input_arity = int(pred1_program.params)
            pred2_input_arity = int(pred2_program.params)
            total_params = pred1_input_arity + pred2_input_arity - len(shared_vars)

            # x_i names the i-th composed argument; route them exactly like compute.
            variable_names = [f"x_{i}" for i in range(total_params)]
            args1, args2 = _route(variable_names, pred1_input_arity, pred2_input_arity)

            pred1_args_string = f"p_0({_format_args(args1)})"
            pred2_args_string = f"p_1({_format_args(args2)})"

            code += f"""
            ReturnExpr None;
            ReturnPred And({pred1_args_string}, {pred2_args_string});
            """

            template = Z3Template(code)
            template.set_args(*args)
            return template

        # Component types follow the argument order used by compute:
        # all pred1 inputs, then pred2's non-shared inputs.
        types1 = pred1.examples.example_structure.component_types
        types2 = pred2.examples.example_structure.component_types
        new_types = [types1[i] for i in range(arity1)] + [
            types2[i] for i in range(arity2) if i not in shared_targets
        ]

        new_concept = Concept(
            name=f"compose_({pred1.name}_with_{pred2.name}_shared_vars={shared_vars})",
            description=f"Predicate composition of {pred1.name} and {pred2.name}",
            symbolic_definition=symbolic_definition,
            computational_implementation=compute
                if pred1.has_computational_implementation()
                and pred2.has_computational_implementation()
                else None,
            example_structure=ExampleStructure(
                concept_type=ConceptType.PREDICATE,
                component_types=tuple(new_types),
                input_arity=new_arity,
            ),
            z3_translation=
                _z3_translate_compose_predicates
                if pred1.has_z3_translation() and pred2.has_z3_translation()
                else None,
        )

        self._transform_predicate_examples(new_concept, pred1, pred2, shared_vars)
        return new_concept

    def _compose_function_to_predicate(
        self, func: Concept, pred: Concept, output_to_input_map: Dict[int, int]
    ) -> Concept:
        """Implement function-to-predicate composition.

        Builds h(x_1..x_n, p_1..p_i) = pred(v_1..v_k), where v_j is func's
        output ``out_idx`` when ``output_to_input_map[out_idx] == j``, and
        otherwise the next extra parameter p (taken after func's n inputs).

        Args:
            func: Inner function concept.
            pred: Outer predicate concept.
            output_to_input_map: Maps output indices of func to input indices
                of pred.

        Returns:
            The composed predicate Concept. Its examples are transferred by
            ``_transform_function_to_predicate_examples``.
        """
        func_in_arity = func.get_input_arity()
        pred_in_arity = pred.get_input_arity()
        # Predicate inputs fed by func outputs; every other input is a parameter.
        mapped_inputs = set(output_to_input_map.values())

        def _translation(*args, func_function, pred_function):
            """Apply func, route its outputs into pred, then apply pred.

            Returns None when too few arguments are supplied or a mapped output
            index is out of range.
            """
            # Ensure we have enough arguments
            if len(args) < (func_in_arity + pred_in_arity - len(output_to_input_map)):
                return None

            func_result = func_function(*args[:func_in_arity])
            if not isinstance(func_result, tuple):
                func_result = (func_result,)

            pred_args = [None] * pred_in_arity
            for out_idx, in_idx in output_to_input_map.items():
                if out_idx >= len(func_result):
                    # Output index out of range
                    return None
                pred_args[in_idx] = func_result[out_idx]

            # Fill remaining arguments from additional parameters
            param_start = func_in_arity
            for i in range(pred_in_arity):
                if i not in mapped_inputs:
                    if param_start >= len(args):
                        # Not enough arguments provided
                        return None
                    pred_args[i] = args[param_start]
                    param_start += 1

            return pred_function(*pred_args)

        def symbolic_definition(*args):
            """Build predicate from function application"""
            return _translation(*args, func_function=func.symbolic, pred_function=pred.symbolic)

        def compute(*args):
            """Compute by applying function then predicate"""
            return _translation(*args, func_function=func.compute, pred_function=pred.compute)

        def _z3_translate_compose_function_to_predicate(*args):
            """Build a Z3 template for pred(f(...), params...).

            This is only attached to the concept for single-output functions
            (see the z3_translation gating below). Every mapped predicate input
            therefore receives the one output f_0(x_0, ...).
            """
            func_template = func.to_z3(*([None] * func.get_input_arity()))
            func_program = func_template.program

            pred_template = pred.to_z3(*([None] * pred.get_input_arity()))
            pred_program = pred_template.program

            code = f"""
            params {func_program.params + pred_program.params - len(output_to_input_map)};
            bounded params 0;   
            """

            code += f"""
            f_0 := Func(
                {func_program.dsl()}
            );
            p_0 := Pred(
                {pred_program.dsl()}
            );
            """

            func_input_arity = int(func_program.params)
            pred_input_arity = int(pred_program.params)

            func_args_string = f"f_0({_format_args([f'x_{i}' for i in range(func_input_arity)])})"

            # Build outer predicate arguments
            pred_args = [None] * pred_input_arity
            for in_idx in output_to_input_map.values():
                pred_args[in_idx] = func_args_string

            # Fill remaining arguments from additional parameters
            param_start = func_input_arity
            for i in range(pred_input_arity):
                if i not in mapped_inputs:
                    if param_start < len(args):
                        pred_args[i] = f"x_{param_start}"
                        param_start += 1
                    else:
                        # Not enough arguments provided
                        return None

            pred_args_string = f"p_0({_format_args(pred_args)})"

            code += f"""
            ReturnExpr None;
            ReturnPred {pred_args_string};
            """

            template = Z3Template(code)
            template.set_args(*args)
            return template

        # Calculate new types
        func_types = func.get_component_types()
        pred_types = pred.get_component_types()
        # func's component types hold its inputs followed by its outputs
        # (the same convention _transform_function_examples relies on).
        func_out_arity = len(func_types) - func_in_arity

        # Input types are function inputs plus parameters for unmapped predicate inputs
        input_types = list(func_types[:func_in_arity]) + [
            pred_types[i] for i in range(pred_in_arity) if i not in mapped_inputs
        ]

        new_concept = Concept(
            name=f"compose_({func.name}_with_{pred.name}_output_to_input_map={output_to_input_map})",
            description=f"Function-to-predicate composition of {func.name} and {pred.name}",
            symbolic_definition=symbolic_definition,
            computational_implementation=compute
                if func.has_computational_implementation()
                and pred.has_computational_implementation()
                else None,
            example_structure=ExampleStructure(
                concept_type=ConceptType.PREDICATE,
                component_types=tuple(input_types),
                input_arity=len(input_types),
            ),
            # The Z3 template cannot select an individual output of f_0, so
            # (as with function composition) restrict it to single-output funcs.
            z3_translation=
                _z3_translate_compose_function_to_predicate
                if func_out_arity == 1
                and func.has_z3_translation()
                and pred.has_z3_translation()
                else None,
        )

        self._transform_function_to_predicate_examples(
            new_concept, func, pred, output_to_input_map
        )
        return new_concept

    def _transform_function_examples(
        self,
        new_concept: Entity,
        outer_concept: Entity,
        inner_concept: Entity,
        output_to_input_map: Dict[int, int],
    ):
        """
        Transform examples from the input concepts to examples for the composed function.

        Each inner example (inputs, outputs) is paired with every outer example
        whose mapped inputs equal the corresponding inner outputs. The composed
        example is laid out as: inner inputs, then the unmapped outer inputs
        (the parameters), then the outer outputs.

        For every accepted example, a nonexample is also attempted. It is built
        by replacing one randomly chosen OUTPUT value with a random integer in
        [0, 10]. If the replacement equals the original value, no nonexample
        is added.

        Args:
            new_concept: The newly created composed concept
            outer_concept: The outer function in the composition
            inner_concept: The inner function in the composition
            output_to_input_map: Maps output indices of inner to input indices of outer

        Values that new_concept rejects with ValueError are skipped.
        """
        inner_examples = inner_concept.examples.get_examples()
        outer_examples = outer_concept.examples.get_examples()

        inner_in_arity = inner_concept.get_input_arity()
        outer_in_arity = outer_concept.get_input_arity()
        # Outer inputs fed by inner outputs; every other outer input is a parameter.
        mapped_inputs = set(output_to_input_map.values())

        for inner_ex in inner_examples:
            inner_inputs = inner_ex.value[:inner_in_arity]
            inner_outputs = inner_ex.value[inner_in_arity:]

            for outer_ex in outer_examples:
                outer_inputs = outer_ex.value[:outer_in_arity]
                outer_outputs = outer_ex.value[outer_in_arity:]

                # Inner outputs must agree with the outer inputs they feed.
                mismatch = any(
                    out_idx >= len(inner_outputs)
                    or outer_inputs[in_idx] != inner_outputs[out_idx]
                    for out_idx, in_idx in output_to_input_map.items()
                )
                if mismatch:
                    continue

                param_values = tuple(
                    outer_inputs[i]
                    for i in range(outer_in_arity)
                    if i not in mapped_inputs
                )
                new_example = inner_inputs + param_values + outer_outputs

                try:
                    new_concept.add_example(new_example)

                    new_in_arity = new_concept.get_input_arity()
                    new_out_arity = (
                        len(new_concept.examples.example_structure.component_types)
                        - new_in_arity
                    )
                    if new_out_arity < 1:
                        continue
                    # Outputs follow the inputs in the example tuple, so the
                    # random output index must be offset by the input arity.
                    # Changing an output of a function guarantees a nonexample,
                    # whereas changing an input (the old behavior) does not.
                    output_index = new_in_arity + random.randint(0, new_out_arity - 1)
                    if output_index >= len(new_example):
                        continue
                    modified_output_value = random.randint(0, 10)
                    if modified_output_value != new_example[output_index]:
                        new_nonexample = (
                            new_example[:output_index]
                            + (modified_output_value,)
                            + new_example[output_index + 1:]
                        )
                        new_concept.add_nonexample(new_nonexample)
                except ValueError:
                    continue

    def _transform_predicate_examples(
        self,
        new_concept: Entity,
        pred1: Entity,
        pred2: Entity,
        shared_vars: Dict[int, int],
    ):
        """Transform examples based on predicate composition mapping.

        Every (example or nonexample of pred1, example or nonexample of pred2)
        pair that agrees on the shared variables yields one composed value:
        pred1's arguments in order, followed by pred2's non-shared arguments.
        This is the argument order used by the composed concept's compute.

        The value is labelled by recomputing pred1 AND pred2. It is added as an
        example if both hold, otherwise as a nonexample.

        When the composed predicate is unary, the first component of every
        source value is additionally evaluated directly with
        ``new_concept.compute``.

        Values that new_concept rejects, or that fail to evaluate, are skipped.
        """
        pred1_examples = pred1.examples.get_examples()
        pred2_examples = pred2.examples.get_examples()
        pred1_nonexamples = pred1.examples.get_nonexamples()
        pred2_nonexamples = pred2.examples.get_nonexamples()

        arity1 = pred1.get_input_arity()
        arity2 = pred2.get_input_arity()
        # pred2 positions filled from pred1 through shared variables.
        shared_targets = set(shared_vars.values())

        def try_add_value(value, is_example):
            """Helper to try adding a value as an example or nonexample."""
            try:
                if is_example:
                    new_concept.add_example(value)
                else:
                    new_concept.add_nonexample(value)
            except (ValueError, TypeError):
                pass

        def process_value_pair(val1, val2):
            """Combine one pred1 value and one pred2 value into a composed value."""
            try:
                value1, value2 = val1.value, val2.value
                if not isinstance(value1, tuple) or not isinstance(value2, tuple):
                    return
                if len(value1) < arity1 or len(value2) < arity2:
                    return

                pred1_args = value1[:arity1]
                pred2_args = value2[:arity2]

                # Both values must agree on every shared variable.
                if any(
                    pred1_args[idx1] != pred2_args[idx2]
                    for idx1, idx2 in shared_vars.items()
                ):
                    return

                # Shared variables appear once (in pred1's block); pred2
                # contributes only its non-shared arguments.
                value = tuple(pred1_args) + tuple(
                    pred2_args[i] for i in range(arity2) if i not in shared_targets
                )

                # AND composition: example iff both predicates hold.
                computed1 = pred1.compute(*pred1_args)
                computed2 = pred2.compute(*pred2_args)

                try_add_value(value, computed1 and computed2)
            except (ValueError, TypeError, IndexError):
                pass

        # Process all combinations of examples and nonexamples, in the order:
        # ex/ex, ex/nonex, nonex/ex, nonex/nonex.
        pairings = (
            (pred1_examples, pred2_examples),
            (pred1_examples, pred2_nonexamples),
            (pred1_nonexamples, pred2_examples),
            (pred1_nonexamples, pred2_nonexamples),
        )
        for left, right in pairings:
            for val1 in left:
                for val2 in right:
                    process_value_pair(val1, val2)

        # Evaluating single values only makes sense for a unary composition;
        # for larger arities new_concept.compute(val) cannot succeed.
        if new_concept.get_input_arity() != 1:
            return

        all_values = set()
        for entry in (
            *pred1_examples,
            *pred2_examples,
            *pred1_nonexamples,
            *pred2_nonexamples,
        ):
            if isinstance(entry.value, tuple) and entry.value:
                try:
                    all_values.add(entry.value[0])
                except TypeError:
                    # Unhashable value; cannot be deduplicated in a set.
                    continue

        for val in all_values:
            value = (val,)
            try:
                computed = new_concept.compute(*value)
                try_add_value(value, computed)
            except (ValueError, TypeError):
                pass

    def _transform_function_to_predicate_examples(
        self,
        new_concept: Entity,
        func: Entity,
        pred: Entity,
        output_to_input_map: Dict[int, int],
    ):
        """Transform examples based on function-to-predicate composition mapping.

        Function outputs are routed into the predicate through
        ``output_to_input_map``. The composed input (the function inputs) is
        added as an example when the predicate holds, otherwise as a
        nonexample.

        Two sources of function inputs are used:
          1. For a unary function, the probe values 0..9 are evaluated with
             ``func.compute``.
          2. If that yields no examples, the existing examples of func are used.

        Compositions that leave some predicate inputs unmapped are skipped.
        Those inputs become parameters of the composed concept, and no values
        are available here to fill them.

        Values that fail to evaluate, or that new_concept rejects, are skipped.
        """
        func_in_arity = func.get_input_arity()
        pred_in_arity = pred.get_input_arity()

        mapped_inputs = set(output_to_input_map.values())
        if any(i not in mapped_inputs for i in range(pred_in_arity)):
            return

        def _record(func_input, func_outputs):
            """Evaluate pred on the routed outputs and record func_input accordingly."""
            pred_args = [None] * pred_in_arity
            for out_idx, in_idx in output_to_input_map.items():
                # IndexError if the function produced fewer outputs than mapped.
                pred_args[in_idx] = func_outputs[out_idx]
            if pred.compute(*pred_args):
                new_concept.add_example(func_input)
            else:
                new_concept.add_nonexample(func_input)

        # Probe a small range of inputs; only meaningful for unary functions.
        if func_in_arity == 1:
            for val in range(10):
                try:
                    func_result = func.compute(val)
                    func_outputs = (
                        func_result if isinstance(func_result, tuple) else (func_result,)
                    )
                    _record((val,), func_outputs)
                except (ValueError, TypeError, IndexError):
                    continue

        # If we still don't have examples, try to derive them from existing examples
        if new_concept.examples.get_examples():
            return

        for func_ex in func.examples.get_examples():
            value = func_ex.value
            # Need the inputs plus at least one output.
            if not isinstance(value, tuple) or len(value) < func_in_arity + 1:
                continue
            try:
                _record(value[:func_in_arity], value[func_in_arity:])
            except (ValueError, TypeError, IndexError):
                continue
