"""
This module implements the MapIterate production rule for iterating functions.
"""
import random
import time
import logging
import concurrent.futures
from concurrent.futures import TimeoutError as FuturesTimeoutError
from typing import List, Optional, Set, Tuple, Any, Union, Dict, Type
from frame.productions.base import ProductionRule
from frame.knowledge_base.entities import (
    Expression,
    Lambda,
    Var,
    ConceptApplication,
    Fold,
    Entity,
    Concept,
    Example,
    ExampleType,
    ExampleStructure,
    ConceptType,
)
from collections import defaultdict

# Default timeout (seconds) for the internal iteration loop.
# NOTE(review): not referenced by the code visible in this file - confirm it is still needed.
DEFAULT_INTERNAL_MAP_ITERATE_TIMEOUT = 0.5
# Per-call timeout (seconds) when waiting for the direct computation of a
# *single* example during example transformation.
DEFAULT_DIRECT_COMPUTE_TIMEOUT = 0.05
# Time limit (seconds) for following *one* chain of base examples during
# example transformation.
DEFAULT_CHAIN_SEARCH_TIMEOUT = 0.05
# Maximum depth to follow chains during example transformation; it also bounds
# the iteration counts tried by direct computation.
MAX_CHAIN_DEPTH = 3
# --- Total time budgets for example transformation ---
# Max total time (seconds) allowed for the entire chain-following section.
TOTAL_CHAIN_TIMEOUT = 0.2
# Max total time (seconds) allowed for the entire direct-computation section.
TOTAL_COMPUTE_TIMEOUT = 0.2

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class MapIterateRule(ProductionRule):
    """
    A production rule that applies a fold operation to a function concept, optionally with an accumulator concept.

    Inputs either:
    1. A unary function f: domain -> domain to create g(a,n) that applies f to a, n times
       Example: successor -> addition
    2. A binary function f: domain × domain -> domain and an accumulator concept to create
       g(a,n) that applies f(a, _) n times starting from the accumulator
       Example: (addition, zero) -> multiplication
               (multiplication, one) -> power
    """

    def __init__(self):
        """Initialise the base production rule with this rule's name, description and type."""
        rule_metadata = {
            "name": "map_iteration",
            "description": "Creates a new concept by folding/iterating a function",
            "type": "Concept",
        }
        super().__init__(**rule_metadata)

    def get_input_types(
        self,
    ) -> List[List[Tuple[Type[Entity], Optional[Union[ConceptType, List[ConceptType]]]]]]:
        """
        Get the types of inputs this rule expects.

        Returns:
            A list of alternative input specifications. Each alternative is a
            list with one ``(entity class, accepted concept type(s))`` tuple per
            positional input. This rule accepts either:
            - a single function concept (the unary function to iterate), or
            - a function concept (the binary function) followed by a concept
              typed FUNCTION or CONSTANT that serves as the accumulator.

            Arity requirements (unary/binary function, arity-0 accumulator) are
            not expressed here; they are enforced by ``can_apply``.
        """
        # Option 1: a single unary function.
        unary_alternative = [(Concept, ConceptType.FUNCTION)]
        # Option 2: a binary function plus an accumulator, which may be typed
        # either FUNCTION or CONSTANT.
        binary_alternative = [
            (Concept, ConceptType.FUNCTION),
            (Concept, [ConceptType.FUNCTION, ConceptType.CONSTANT]),
        ]
        return [unary_alternative, binary_alternative]

    def get_valid_parameterizations(self, *inputs: Entity) -> List[Dict[str, Any]]:
        """
        Return the parameterizations under which this rule may be applied to ``inputs``.

        The rule takes no parameters of its own, so the result is either
        ``[{}]`` (applicable) or ``[]`` (not applicable).

        Args:
            *inputs: Candidate input entities, validated with ``can_apply``.

        Returns:
            ``[{}]`` when the inputs pass ``can_apply`` and the primary function
            concept is not nested too deeply in previous map-iterate
            applications; otherwise ``[]``.
        """
        # Basic validation: ensure inputs are Concepts and meet type requirements.
        if not self.can_apply(*inputs, verbose=False):
            return []

        # Heuristic against timeouts: skip functions that are already the
        # product of several nested iterations. Concepts that never went
        # through this rule may not carry the attribute, so default to depth 0
        # (the same default `apply` uses when computing the new depth).
        max_depth = 2
        if getattr(inputs[0], "map_iterate_depth", 0) > max_depth:
            return []

        return [{}]

    def can_apply(self, *inputs: Entity, verbose: bool = True) -> bool:
        """
        Check if this rule can be applied to the inputs.

        Requirements for unary function:
        1. Single input concept
        2. Input concept must be a unary function
        3. Types must be numeric
        4. Output arity must be 1

        Requirements for binary function:
        1. Two input concepts: function and accumulator
        2. First input must be a binary function
        3. Second input must have arity 0 and be typed FUNCTION or CONSTANT
        4. Types must be numeric
        5. Output arity must be 1

        Args:
            *inputs: Input entities to check
            verbose: Whether to print debug information

        Returns:
            True when every requirement holds, False otherwise. In verbose mode
            the first failed requirement is printed.
        """

        def reject(reason: str) -> bool:
            # Report why the inputs were refused (verbose only) and signal failure.
            if verbose:
                print(f"❌ Failed: {reason}")
            return False

        if verbose:
            print("Checking requirements:")

        # First check if inputs match any of our expected input type alternatives
        if not self.check_input_types(*inputs):
            return reject("Inputs don't match expected types")

        # From here on every input is known to be a Concept.
        if not inputs or not all(isinstance(x, Concept) for x in inputs):
            return reject("All inputs must be concepts")

        base_concept = inputs[0]
        if verbose:
            print(f"✓ First input is a concept: {base_concept.name}")

        if base_concept.examples.example_structure.concept_type != ConceptType.FUNCTION:
            return reject("First input concept must be a function")

        input_arity = base_concept.get_input_arity()

        if input_arity == 1:
            if len(inputs) != 1:
                return reject("Unary function requires exactly one input")
            if verbose:
                print("✓ Valid unary function")
        elif input_arity == 2:
            if len(inputs) != 2:
                return reject("Binary function requires exactly two inputs")
            acc_concept = inputs[1]
            if acc_concept.get_input_arity() != 0:
                return reject("Second input must be a constant (arity 0)")
            if verbose:
                print(f"✓ Second input is a constant concept: {acc_concept.name}")
        else:
            return reject("Function must be unary or binary")

        # Check that all types are numeric
        component_types = base_concept.get_component_types()
        if not all(t == ExampleType.NUMERIC for t in component_types):
            return reject("Function must map numeric domain to numeric domain")

        # Whatever is not an input component is an output component.
        output_arity = len(component_types) - input_arity
        if output_arity != 1:
            return reject(
                f"Function must have output arity of 1, but got {output_arity}"
            )

        # Binary case: the accumulator's arity was verified above, so only its
        # declared concept type remains to be checked.
        if input_arity == 2:
            acc_structure = getattr(inputs[1].examples, "example_structure", None)
            if acc_structure is None or acc_structure.concept_type not in (
                ConceptType.FUNCTION,
                ConceptType.CONSTANT,
            ):
                return reject("Second input must be a constant concept (arity 0)")

        if verbose:
            print(
                f"✓ Valid {'binary' if input_arity == 2 else 'unary'} function domain -> domain with output arity 1"
            )
        return True

    def determine_verification_capabilities(self, *inputs: Entity) -> Tuple[bool, bool]:
        """
        Work out whether the iterated concept may accept new examples and nonexamples.

        The resulting concept inherits the capability flags of the function
        being iterated. When that function is binary and an accumulator concept
        is supplied, each flag additionally requires the accumulator's
        corresponding flag.

        Args:
            *inputs: The function concept, optionally followed by the accumulator.

        Returns:
            Tuple[bool, bool]: (can_add_examples, can_add_nonexamples);
            ``(False, False)`` when there are no inputs or any input is not a
            Concept.
        """
        if len(inputs) == 0 or any(not isinstance(entity, Concept) for entity in inputs):
            return (False, False)

        function_concept = inputs[0]
        uses_accumulator = (
            function_concept.examples.example_structure.input_arity == 2
            and len(inputs) > 1
        )

        # Start from the function's own flags.
        examples_ok = function_concept.can_add_examples
        nonexamples_ok = function_concept.can_add_nonexamples
        if not uses_accumulator:
            return (examples_ok, nonexamples_ok)

        # With an accumulator, both concepts must support the capability.
        accumulator = inputs[1]
        acc_examples_ok = accumulator.can_add_examples
        acc_nonexamples_ok = accumulator.can_add_nonexamples
        return (examples_ok and acc_examples_ok, nonexamples_ok and acc_nonexamples_ok)

    def apply(self, *inputs: Entity) -> Entity:
        """
        Build the iterated concept g(a, n) from the given inputs.

        Unary f: domain -> domain
            g(a, n) applies f to a, n times (g(a, 0) == a).

        Binary f: domain × domain -> domain with accumulator acc
            g(a, n) starts from acc's value and applies f(a, _) n times.

        The new concept's ``map_iterate_depth`` is one more than the deepest
        input's, and its examples are seeded via ``_transform_examples``.

        Raises:
            ValueError: If ``can_apply`` rejects the inputs.
        """
        if not self.can_apply(*inputs):
            raise ValueError("Cannot apply MapIteration to these inputs")

        base_concept = inputs[0]
        is_binary = base_concept.get_input_arity() == 2

        # Capability flags of the resulting concept.
        can_add_examples, can_add_nonexamples = self.determine_verification_capabilities(*inputs)

        def iterate_compute(a, n):
            """Evaluate g(a, n) by running the step function n times."""
            if n < 0:
                raise ValueError("Number of iterations must be non-negative")

            if is_binary:
                # Start from the accumulator and keep `a` fixed as first argument.
                result = inputs[1].compute()

                def step(current):
                    return base_concept.compute(a, current)
            else:
                # Start from `a` itself; zero iterations leave it untouched.
                result = a

                def step(current):
                    return base_concept.compute(current)

            for _ in range(n):
                result = step(result)
            return result

        def start_term(a):
            # Symbolic starting value of the fold.
            return inputs[1].symbolic() if is_binary else a

        def step_term(a):
            # Symbolic step function: x -> f(a, x) (binary) or x -> f(x) (unary).
            bound = Var("x")
            arguments = (a, bound) if is_binary else (bound,)
            return Lambda("x", ConceptApplication(base_concept, *arguments))

        if is_binary:
            concept_name = f"iterate_({base_concept.name}_with_{inputs[1].name})"
        else:
            concept_name = f"iterate_({base_concept.name})"

        new_concept = Concept(
            name=concept_name,
            description=f"Applies {base_concept.name} iteratively",
            symbolic_definition=lambda a, n: Fold(n, start_term(a), step_term(a)),
            computational_implementation=iterate_compute,
            example_structure=ExampleStructure(
                concept_type=ConceptType.FUNCTION,
                # (value, iteration count, result) - all numeric.
                component_types=(ExampleType.NUMERIC,) * 3,
                input_arity=2,
            ),
            can_add_examples=can_add_examples,
            can_add_nonexamples=can_add_nonexamples,
            z3_translation=None,
        )

        # One level deeper than the most deeply iterated input.
        input_depths = [getattr(inp, 'map_iterate_depth', 0) for inp in inputs]
        new_concept.map_iterate_depth = max(input_depths) + 1

        # Seed examples; the accumulator is only relevant in the binary case.
        accumulator = inputs[1] if is_binary else None
        self._transform_examples(new_concept, base_concept, accumulator_concept=accumulator)

        return new_concept

    def _transform_examples(self, new_concept: Entity, base_concept: Entity, accumulator_concept: Optional[Entity] = None):
        """
        Seed ``new_concept`` with examples derived from ``base_concept``.

        Two best-effort passes run, each bounded by a time budget:

        1. Chain following: starting from a seed value (``a`` itself for a
           unary base, the accumulator's value for a binary base), the current
           value is repeatedly looked up among the base concept's own examples
           to obtain the results for n = 1 .. MAX_CHAIN_DEPTH + 1.
        2. Direct computation: iteration counts the chain did not reach are
           evaluated with ``new_concept.compute`` on a worker thread, each call
           limited by DEFAULT_DIRECT_COMPUTE_TIMEOUT.

        Every example is recorded through ``_add_example_and_non``, which also
        attempts a matching non-example.

        Args:
            new_concept: The iterated concept that receives the examples.
            base_concept: The unary or binary function being iterated.
            accumulator_concept: Arity-0 concept providing the starting value;
                required when ``base_concept`` is binary.
        """
        input_arity = base_concept.get_input_arity()

        if input_arity == 1:
            # n = 1 comes from the base examples themselves, so direct
            # computation starts at n = 2.
            label, first_direct_n = "unary", 2
        elif input_arity == 2:
            label, first_direct_n = "binary", 1
            if accumulator_concept is None:
                logger.warning("MapIterate (binary) ... Skipping example transformation.")
                return
            try:
                if accumulator_concept.get_input_arity() != 0:
                    logger.warning(f"Accumulator {accumulator_concept.name}... not arity 0... Skipping.")
                    return
                acc_value = accumulator_concept.compute()
            except Exception as acc_err:
                logger.error(f"Failed compute acc value for {new_concept.name}: {acc_err}. Skipping.")
                return
        else:
            logger.warning(f"MapIterate rule does not support transforming examples for base concept arity {input_arity}")
            return

        # Each chain is (start_a, seed value for n = 0, lookup table current -> next).
        chains: List[Tuple[Any, Any, Dict[Any, Any]]] = []
        if input_arity == 1:
            # f(a) = b: one shared table; the first example seen for an `a` wins.
            successor: Dict[Any, Any] = {}
            for ex in base_concept.get_examples():
                if isinstance(ex.value, tuple) and len(ex.value) == 2:
                    arg, result = ex.value
                    successor.setdefault(arg, result)
            chains = [(arg, arg, successor) for arg in successor]
        else:
            # f(a, x) = y: one table per fixed `a`, every chain seeded with acc.
            tables: Dict[Any, Dict[Any, Any]] = {}
            for ex in base_concept.get_examples():
                if isinstance(ex.value, tuple) and len(ex.value) == 3:
                    fixed, arg, result = ex.value
                    tables.setdefault(fixed, {}).setdefault(arg, result)
            chains = [(fixed, acc_value, table) for fixed, table in tables.items()]

        # Iteration counts already handled per start value, so the direct
        # computation pass does not need to rescan new_concept's examples.
        covered: Dict[Any, Set[int]] = {}

        # --- Chain following ---
        logger.debug(f"Starting {label} chain following for {new_concept.name}...")
        chain_deadline = time.monotonic() + TOTAL_CHAIN_TIMEOUT
        for start_a, seed_value, table in chains:
            if time.monotonic() > chain_deadline:
                logger.warning(f"Total chain following time exceeded {TOTAL_CHAIN_TIMEOUT}s for {new_concept.name}. Stopping chain processing.")
                break
            try:
                self._follow_example_chain(
                    new_concept, start_a, seed_value, table, covered.setdefault(start_a, set())
                )
            except Exception as chain_err:
                logger.error(f"{label.capitalize()} chain error for {new_concept.name}, a={start_a}: {chain_err}")
        logger.debug(f"Finished {label} chain following section for {new_concept.name}.")

        # --- Direct computation ---
        logger.debug(f"Starting {label} direct computation for {new_concept.name}...")
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            self._compute_uncovered_examples(
                executor, new_concept, [start_a for start_a, _, _ in chains], first_direct_n, covered
            )
        finally:
            # Do not block on a computation that may still be running.
            executor.shutdown(wait=False)
        logger.debug(f"Finished {label} direct computation section for {new_concept.name}.")

    def _follow_example_chain(self, concept: Entity, start_a: Any, seed_value: Any, table: Dict[Any, Any], covered_n: Set[int]):
        """Record g(start_a, n) for n = 0, 1, ... by walking ``table`` from ``seed_value``.

        The walk stops when a value has no entry in ``table``, after
        MAX_CHAIN_DEPTH + 1 steps, or once DEFAULT_CHAIN_SEARCH_TIMEOUT seconds
        have elapsed. Every n that was handled is added to ``covered_n``.
        """
        self._add_example_and_non(concept, start_a, 0, seed_value)
        covered_n.add(0)

        current = seed_value
        chain_started = time.monotonic()
        for n_val in range(1, MAX_CHAIN_DEPTH + 2):
            if time.monotonic() - chain_started > DEFAULT_CHAIN_SEARCH_TIMEOUT:
                logger.debug(f"Single chain search timed out for a={start_a}")
                break
            current = table.get(current)
            if current is None:
                # No base example continues the chain from here.
                break
            self._add_example_and_non(concept, start_a, n_val, current)
            covered_n.add(n_val)

    def _compute_uncovered_examples(self, executor, concept: Entity, start_values: List[Any], first_n: int, covered: Dict[Any, Set[int]]):
        """Directly compute g(a, n) for each start value and each n in
        ``first_n .. MAX_CHAIN_DEPTH + 1`` not listed in ``covered``, stopping
        once TOTAL_COMPUTE_TIMEOUT seconds have elapsed.
        """
        deadline = time.monotonic() + TOTAL_COMPUTE_TIMEOUT
        for start_a in start_values:
            already_covered = covered.get(start_a, ())
            for n_val in range(first_n, MAX_CHAIN_DEPTH + 2):
                if time.monotonic() > deadline:
                    logger.warning(f"Total direct computation time exceeded {TOTAL_COMPUTE_TIMEOUT}s for {concept.name}. Stopping computation processing.")
                    return
                if n_val not in already_covered:
                    self._try_direct_compute_example(executor, concept, start_a, n_val)

    # --- Helper Methods for _transform_examples ---
    def _add_example_and_non(self, concept: Entity, a_val: Any, n_val: int, actual_result: Any):
        """Record ``(a_val, n_val, actual_result)`` as an example and try to pair it with a non-example.

        Nothing is recorded when any int/float component exceeds 50
        (temporary optimisation, see TODO). Failures while adding are logged
        as warnings and not propagated.
        """
        candidate = (a_val, n_val, actual_result)

        # TODO(optim): Skip adding examples/non-examples with large values (>50) as a temporary optimization.
        for component in candidate:
            if isinstance(component, (int, float)) and component > 50:
                return

        try:
            concept.add_example(candidate)
            # Derive a non-example from the same inputs with a different result.
            self._try_add_non_example(concept, a_val, n_val, actual_result)
        except Exception as add_err:
            logger.warning(f"Failed adding example/non-example pair for ({a_val}, {n_val}): {add_err}")
             
    def _try_add_non_example(self, concept: Entity, a_val: Any, n_val: int, actual_result: Any):
        """Add ``(a_val, n_val, wrong)`` as a non-example, where ``wrong`` differs from ``actual_result``.

        ``wrong`` is drawn at random from the integers 0..10, converted to the
        type of ``actual_result``, excluding any value equal to
        ``actual_result``. A non-example is therefore produced whenever at
        least one such value exists, instead of being skipped whenever a single
        random draw happens to hit the true result.

        Nothing is added when the conversion fails, when no differing value
        exists, or when any int/float component exceeds 50. Errors are logged
        at debug level and never propagated.
        """
        try:
            result_type = type(actual_result)
            wrong_results = []
            for raw in range(0, 11):
                candidate = raw
                if result_type is not int:
                    try:
                        candidate = result_type(raw)  # Match the result's type
                    except (ValueError, TypeError):
                        logger.debug(f"Cannot convert random value {raw} to type {result_type} for non-example.")
                        return  # Cannot compare if types mismatch and conversion fails
                if candidate != actual_result:
                    wrong_results.append(candidate)

            if not wrong_results:
                # Every candidate equals the true result; nothing to perturb to.
                return

            non_example = (a_val, n_val, random.choice(wrong_results))
            # TODO(optim): Skip adding non-examples with large values (>50) as a temporary optimization.
            if any(isinstance(x, (int, float)) and x > 50 for x in non_example):
                return
            concept.add_nonexample(non_example)

        except Exception as non_ex_add_err:
            # Log less verbosely for non-example failures
            logger.debug(f"Failed to add non-example for ({a_val}, {n_val}): {non_ex_add_err}")

    def _try_direct_compute_example(self, executor, concept: Entity, a_val: Any, n_val: int):
        """Compute ``concept.compute(a_val, n_val)`` on ``executor`` and record the result.

        The call is awaited for at most DEFAULT_DIRECT_COMPUTE_TIMEOUT seconds.
        On success the result is stored via ``_add_example_and_non``. On
        timeout the future is cancelled so that, if it had not started yet
        (e.g. it is queued behind an earlier call that is still running), it
        never runs. Timeouts and other errors are logged, not raised.

        Args:
            executor: Executor providing ``submit`` (a ThreadPoolExecutor in this module).
            concept: Concept whose ``compute`` is evaluated and which receives the example.
            a_val: Start value passed as first argument.
            n_val: Iteration count passed as second argument.
        """
        compute_input_tuple = (a_val, n_val)
        future = None
        try:
            future = executor.submit(concept.compute, *compute_input_tuple)
            computed_output = future.result(timeout=DEFAULT_DIRECT_COMPUTE_TIMEOUT)

            # Use the helper to add example and non-example
            self._add_example_and_non(concept, a_val, n_val, computed_output)

        except FuturesTimeoutError:
            # result() giving up does not stop the task: cancel it so pending
            # work is dropped (a task that is already running cannot be
            # cancelled and its eventual result is simply ignored).
            if future is not None:
                future.cancel()
            logger.warning(f"Direct compute timeout ({DEFAULT_DIRECT_COMPUTE_TIMEOUT}s) for {concept.name} input {compute_input_tuple}. Skipping.")
        except Exception as compute_err:
            logger.error(f"Error during direct compute for {concept.name} input {compute_input_tuple}: {compute_err}")
