"""
This module implements the Size production rule for sets and predicates.
"""

from typing import Optional, Tuple, List, Set, Any, Dict, Type, Union
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    ConceptApplication,
    GroupDomain,
    NatDomain,
    SetDomain,
    GroupElementDomain,
    Entity,
    Concept,
    ExampleType,
    ExampleStructure,
    ConceptType,
    Set,
    SetCardinality,
    Lambda,
    TupleDomain,
)
from itertools import combinations


class SizeRule(ProductionRule):
    """
    Production rule that creates a new concept representing the cardinality (size) of a set or a predicate-defined set.
    
    Accepts a single input concept, which is either:

    1. A set-valued concept (last component type SET) – the new concept simply returns
       SetCardinality(ConceptApplication(concept, ...))

    2. A predicate concept – by quantifying over one or more arguments.
       For a predicate P(x₁, …, xₙ) and a chosen list of indices,
       the new concept returns:

         SetCardinality({ (xᵢ₁, …, xᵢₖ) : D | P(..., xᵢ₁, …, xᵢₖ, ...) })

       where the quantified variables replace the arguments at the specified indices.
    """

    def __init__(self):
        """Register this rule under the name ``size`` as a concept-producing rule."""
        summary = (
            "Creates a concept returning the size (cardinality) of a set or a "
            "predicate-defined set by quantifying over specified arguments"
        )
        super().__init__(name="size", description=summary, type="Concept")

    def get_input_types(
        self,
    ) -> List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]:
        """
        Describe the single input this rule consumes.

        Set-valued concepts are functions whose output is a SET, so FUNCTION is
        admitted alongside PREDICATE to cover both supported input shapes.

        Returns:
            List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]:
                A single ``(Concept, [PREDICATE, FUNCTION])`` entry.
        """
        accepted_kinds = [ConceptType.PREDICATE, ConceptType.FUNCTION]
        single_input = (Concept, accepted_kinds)
        return [single_input]

    def get_valid_parameterizations(self, *inputs: Entity) -> List[Dict[str, Any]]:
        """
        Get valid parameterizations for applying the size rule to the given concept.

        Args:
            *inputs: A single concept (set-valued or predicate)

        Returns:
            List of valid parameter dictionaries. For predicates each contains:
            - indices_to_quantify: List of indices to quantify over

        Note:
            For set-valued concepts (last component type SET), returns a single
            empty dictionary (no parameters needed).
            For predicate concepts, returns one dictionary per non-empty subset
            of argument positions. The subsets are ordered by size, and
            lexicographically within a size. At most the first 100 are returned.
            Subsets are generated lazily, so a high-arity predicate never
            materialises all 2**arity - 1 of them.
        """
        from itertools import islice

        max_parameterizations = 100

        # Check if we have the right number and types of inputs
        if len(inputs) != 1 or not isinstance(inputs[0], Concept):
            return []

        concept = inputs[0]
        comp_types = concept.get_component_types()
        input_arity = concept.get_input_arity()

        # Requires at least one component type (otherwise this is not a valid concept)
        if len(comp_types) == 0:
            return []

        # Case 1: Set concept (output type is SET) - no parameters needed
        if comp_types[-1] == ExampleType.SET:
            return [{}]

        # Case 2: Predicate concept
        if concept.examples.example_structure.concept_type != ConceptType.PREDICATE:
            return []  # Not a predicate or set concept

        # The outer range walks subset sizes upward, and combinations() yields
        # each size in lexicographic order. The stream is therefore already
        # sorted by size (fewer quantified variables = more specific), so taking
        # a prefix equals "build all, stable-sort by size, keep the first 100".
        candidates = (
            {"indices_to_quantify": list(indices)}
            for size in range(1, input_arity + 1)
            for indices in combinations(range(input_arity), size)
        )
        return list(islice(candidates, max_parameterizations))

    def can_apply(
        self,
        *inputs: Entity,
        indices_to_quantify: Optional[List[int]] = None,
        verbose: bool = True,
    ) -> bool:
        """
        Check whether the size rule can be applied with the given parameters.

        Args:
            *inputs: Must be exactly one Concept.
            indices_to_quantify: Argument positions to quantify over. Must be
                None or empty when the concept's last component type is SET.
                Otherwise the concept must be a predicate, and this must be a
                non-empty list of distinct indices in ``range(input_arity)``.
            verbose: If True, print the reason for any rejection.

        Returns:
            True if ``apply`` may be called with these arguments, else False.
        """
        # Expect exactly one input concept.
        if len(inputs) != 1 or not isinstance(inputs[0], Concept):
            if verbose:
                print("❌ Failed: Must have exactly one input of type Concept")
            return False

        concept = inputs[0]
        comp_types = concept.get_component_types()
        input_arity = concept.get_input_arity()

        # Guard the comp_types[-1] lookups below (get_valid_parameterizations
        # rejects such concepts as well).
        if len(comp_types) == 0:
            if verbose:
                print("❌ Failed: Concept has no component types")
            return False

        # If the concept's output is a set, then indices_to_quantify should not be provided.
        if comp_types[-1] == ExampleType.SET:
            if indices_to_quantify is not None and len(indices_to_quantify) > 0:
                if verbose:
                    print(
                        "❌ Failed: For set concepts, do not specify indices to quantify over"
                    )
                return False
            return True

        # Otherwise, for predicate concepts, indices_to_quantify must be provided and non-empty.
        if concept.examples.example_structure.concept_type != ConceptType.PREDICATE:
            if verbose:
                print(
                    "❌ Failed: Concept must be either a set (output type SET) or a predicate"
                )
            return False

        if indices_to_quantify is None or len(indices_to_quantify) == 0:
            if verbose:
                print(
                    "❌ Failed: Must specify at least one index to quantify over for predicate concepts"
                )
            return False

        for idx in indices_to_quantify:
            if not (0 <= idx < input_arity):
                if verbose:
                    print(f"❌ Failed: Index {idx} to quantify is out of range")
                return False

        # A repeated index would create a bound variable that never occurs in the
        # predicate, plus an extra factor in the product domain. The resulting
        # count would then be wrong.
        if len(set(indices_to_quantify)) != len(indices_to_quantify):
            if verbose:
                print("❌ Failed: Indices to quantify must be distinct")
            return False

        return True

    def determine_verification_capabilities(self, *inputs: Entity) -> Tuple[bool, bool]:
        """
        Report which kinds of examples the produced size concept can accept.

        Confirming an exact size would require knowing every element of the
        counted set (a universal claim), so examples are never verifiable here.
        A size claim can be refuted by exhibiting more elements than claimed.
        That is possible exactly when the input concept can itself add examples.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples).
            With no inputs the result is (False, True).
        """
        nonexamples_verifiable = inputs[0].can_add_examples if inputs else True
        return False, nonexamples_verifiable

    def apply(
        self,
        *inputs: Entity,
        indices_to_quantify: Optional[List[int]] = None,
        verbose: bool = False,
    ) -> Entity:
        """
        Create the size concept for the given input concept.

        Case 1 (last component type SET): the new concept takes the same
        arguments and returns ``len`` of the set returned by the input concept.

        Case 2 (predicate): the arguments at ``indices_to_quantify`` become
        bound variables ``x0, x1, ...``. Variable ``xk`` belongs to position
        ``indices_to_quantify[k]``. The remaining arguments stay free, and the
        new concept returns the cardinality of the set of bound-variable tuples
        satisfying the predicate.

        Args:
            *inputs: A single Concept.
            indices_to_quantify: Positions to quantify over (predicates only).
            verbose: Passed to ``can_apply`` to print why an application fails.

        Returns:
            The new FUNCTION concept, whose last component type is NUMERIC.

        Raises:
            ValueError: If ``can_apply`` rejects the inputs.
        """
        if not self.can_apply(
            *inputs, indices_to_quantify=indices_to_quantify, verbose=verbose
        ):
            raise ValueError("Cannot apply SizeRule to these inputs")

        concept = inputs[0]
        comp_types = concept.get_component_types()
        input_arity = concept.get_input_arity()

        # Get verification capabilities
        can_add_examples, can_add_nonexamples = self.determine_verification_capabilities(*inputs)

        # --- Case 1: Set concept (output type is SET) ---
        if comp_types[-1] == ExampleType.SET:

            def size_compute(*args):
                # The input concept may return something without a length;
                # report that as None instead of raising.
                s = concept.compute(*args)
                try:
                    return len(s)
                except Exception:
                    return None

            new_concept = Concept(
                # Name kept byte-identical to earlier versions (it embeds
                # indices_to_quantify even here); other code may key on it.
                name=f"size_of_({concept.name}_indices_{indices_to_quantify})",
                description=f"Cardinality of the set returned by {concept.name}",
                symbolic_definition=lambda *args: ConceptApplication(
                    SetCardinality, ConceptApplication(concept, *args)
                ),
                computational_implementation=size_compute,
                example_structure=ExampleStructure(
                    concept_type=ConceptType.FUNCTION,
                    # tuple() so this works whether comp_types is a tuple or a list.
                    component_types=tuple(comp_types[:-1]) + (ExampleType.NUMERIC,),
                    input_arity=input_arity,
                ),
                lean4_translation=lambda *args: f"(Finset.card {concept.to_lean4(*args)})",
                prolog_translation=lambda *args: f"length({concept.to_prolog(*args)}, Result)",
                z3_translation=None,
                can_add_examples=can_add_examples,
                can_add_nonexamples=can_add_nonexamples,
            )
            new_concept.map_iterate_depth = concept.map_iterate_depth

            self._transform_examples_set(new_concept, concept)
            return new_concept

        # --- Case 2: Predicate concept ---
        # Snapshot the caller's list. The closures below outlive this call, so a
        # later in-place mutation of the caller's list must not alter the concept.
        quantified = list(indices_to_quantify)
        quantified_positions = set(quantified)

        # Free arguments are those indices not being quantified.
        free_indices = [i for i in range(input_arity) if i not in quantified_positions]

        # One bound variable per quantified index, numbered in the order of `quantified`.
        bound_vars = [f"x{k}" for k in range(len(quantified))]

        # Bound variable substituted at each quantified position: xk goes to
        # quantified[k] (first occurrence wins, like list.index). The symbolic
        # definition and both translations share this mapping, so they agree
        # even when indices_to_quantify is not sorted.
        var_at_position = {}
        for k, position in enumerate(quantified):
            var_at_position.setdefault(position, bound_vars[k])

        def full_args(free_args):
            """Interleave free-argument terms with bound-variable names, by position."""
            result = []
            next_free = 0
            for i in range(input_arity):
                if i in var_at_position:
                    result.append(var_at_position[i])
                else:
                    result.append(free_args[next_free])
                    next_free += 1
            return result

        # Default domain for each quantified variable, chosen from its component type.
        domains = []
        for i in quantified:
            if comp_types[i] == ExampleType.NUMERIC:
                domains.append(NatDomain())
            elif comp_types[i] == ExampleType.SET:
                domains.append(SetDomain())
            elif comp_types[i] == ExampleType.GROUP:
                domains.append(GroupDomain())
            elif comp_types[i] == ExampleType.GROUPELEMENT:
                domains.append(GroupElementDomain())
            else:
                domains.append(NatDomain())

        # Overall domain for the quantified variables (product when more than one).
        default_domain = domains[0] if len(domains) == 1 else TupleDomain(tuple(domains))

        def symbolic_definition(*args):
            # args are the free-argument expressions, in positional order.
            lam = ConceptApplication(concept, *full_args(args))
            # Nest lambdas so that x0 is the outermost binder.
            for var in reversed(bound_vars):
                lam = Lambda(var, lam)
            set_expr = Set(domain=default_domain, predicate=lam)
            return ConceptApplication(SetCardinality, set_expr)

        # New component types: free argument types plus NUMERIC output.
        new_comp_types = [comp_types[i] for i in free_indices] + [ExampleType.NUMERIC]

        new_concept = Concept(
            name=f"size_of_({concept.name}_indices_{indices_to_quantify})",
            description=(
                f"Cardinality of the set of values for arguments {indices_to_quantify} satisfying {concept.name}"
            ),
            symbolic_definition=symbolic_definition,
            computational_implementation=None,
            example_structure=ExampleStructure(
                concept_type=ConceptType.FUNCTION,
                component_types=tuple(new_comp_types),
                input_arity=len(free_indices),
            ),
            lean4_translation=lambda *args: (
                f"(Finset.card {{{', '.join(bound_vars)} : {default_domain.to_lean4()} | "
                f"{concept.to_lean4(*full_args(args))} }})"
            ),
            # NOTE(review): length/2 applied to a goal term does not count the
            # goal's solutions. Also, if to_prolog emits bound names verbatim,
            # the lowercase x0, x1, ... are Prolog atoms, not variables.
            # Presumably this wants findall/3 (or setof/3) over capitalised
            # variables; TODO confirm.
            prolog_translation=lambda *args: (
                f"length({concept.to_prolog(*full_args(args))}, Result)"
            ),
            z3_translation=None,
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
        )
        new_concept.map_iterate_depth = concept.map_iterate_depth
        self._transform_examples_predicate(new_concept, concept, quantified)
        return new_concept

    def _insert_quantified_placeholders(
        self,
        args: Tuple[str, ...],
        indices_to_quantify: List[int],
        input_arity: int,
        bound_vars: List[str],
    ) -> Tuple[str, ...]:
        """
        Interleave free-argument translations with bound-variable placeholders.

        Args:
            args: Translations for the free arguments, in positional order
                (length = input_arity - len(indices_to_quantify)).
            indices_to_quantify: Positions where a bound variable is inserted.
            input_arity: Total number of arguments of the underlying concept.
            bound_vars: Bound variable names. ``bound_vars[k]`` belongs to
                position ``indices_to_quantify[k]``.

        Returns:
            A tuple of length ``input_arity`` with free arguments and bound
            variable names in positional order.

        Variables are paired with positions through their index in
        ``indices_to_quantify``, not by left-to-right order. ``apply``'s
        symbolic definition and quantifier-domain ordering use the same
        pairing, so unsorted index lists give consistent translations.
        """
        # Position -> bound variable (first occurrence wins, like list.index).
        var_at_position = {}
        for k, position in enumerate(indices_to_quantify):
            var_at_position.setdefault(position, bound_vars[k])

        full = []
        free_index = 0
        for i in range(input_arity):
            if i in var_at_position:
                full.append(var_at_position[i])
            else:
                full.append(args[free_index])
                free_index += 1
        return tuple(full)

    def _transform_examples_set(self, new_concept: Entity, concept: Entity):
        """
        Copy the examples of a set-valued concept onto its size concept.

        Each example value ``(*args, S)`` of ``concept`` becomes the example
        ``(*args, len(S))`` of ``new_concept``. Values are skipped if they are
        not non-empty tuples, if their last entry has no length, or if
        ``add_example`` raises.

        Nonexamples are deliberately not transferred. ``f(args) != S`` does not
        imply ``|f(args)| != |S|``. For example, ``f(x) = {x}`` has nonexample
        ``(1, {2})``, yet ``|f(1)| == |{2}| == 1``. Mapping them would record
        false nonexamples.
        """
        for ex in concept.examples.get_examples():
            value = ex.value
            # Expect (*args, set_value); an empty tuple has no set to measure.
            if not isinstance(value, tuple) or len(value) == 0:
                continue
            args, set_val = value[:-1], value[-1]
            try:
                size_val = len(set_val)
            except (TypeError, ValueError, OverflowError):
                # No usable length (e.g. a None/failed value); nothing to record.
                continue
            try:
                new_concept.add_example(args + (size_val,))
            except Exception:
                # add_example may reject the value; transfer is best-effort.
                continue

    def _transform_examples_predicate(
        self, new_concept: Entity, concept: Entity, indices_to_quantify: List[int]
    ):
        """
        Derive examples and nonexamples of a predicate-size concept.

        Examples of ``concept`` are grouped by their free (non-quantified)
        arguments. For each group with ``c`` distinct quantified tuples:

        - ``(*free_args, c)`` is added as an example.
        - ``(*free_args, k)`` for every ``k < c`` is added as a nonexample.
          ``c`` known witnesses prove the counted set has at least ``c``
          elements, so any smaller size claim is false.

        Nonexamples of ``concept`` are not used. The number of known
        non-witnesses says nothing about the number of witnesses.

        Some values are skipped: non-tuples, tuples shorter than the input
        arity, and values with unhashable entries. Additions that the new
        concept raises on are also skipped.
        """
        input_arity = concept.get_input_arity()
        quantified = set(indices_to_quantify)
        free_positions = [i for i in range(input_arity) if i not in quantified]

        # free-argument tuple -> set of distinct quantified-value tuples (witnesses)
        witnesses = {}
        for ex in concept.examples.get_examples():
            value = ex.value
            if not isinstance(value, tuple) or len(value) < input_arity:
                continue
            key = tuple(value[i] for i in free_positions)
            witness = tuple(value[i] for i in indices_to_quantify)
            # Check hashability before touching the dict, so a failure cannot
            # leave behind an empty group that would later be counted as 0.
            try:
                hash((key, witness))
            except TypeError:
                continue
            witnesses.setdefault(key, set()).add(witness)

        # NOTE(review): c is only a lower bound on the true size unless the
        # examples of `concept` list every witness for that key. Recording it as
        # an exact example presumably relies on that; TODO confirm.
        for key, found in witnesses.items():
            try:
                new_concept.add_example(key + (len(found),))
            except Exception:
                continue

        for key, found in witnesses.items():
            for smaller in range(len(found)):
                try:
                    new_concept.add_nonexample(key + (smaller,))
                except Exception:
                    continue