"""Data structures for managing learned abstractions in FunSearch."""
import dataclasses
from typing import Optional, List, Dict
import logging
import textwrap
import random # Add random import
import ast # Add ast import

logger = logging.getLogger(__name__)

@dataclasses.dataclass(frozen=True)
class Abstraction:
    """Represents a single learned abstraction function."""
    name: str         # e.g., _abstraction_lib_func_1
    signature: str    # e.g., (entity_id: str, graph: KnowledgeGraph) -> float
    description: str  # Concise natural language docstring from LLM
    code: str         # Full Python code definition of the function, including 'def' line

    def get_code_without_docstring(self) -> str:
        """Return the function's code with its leading docstring removed.

        The original ``code`` string is returned unchanged when:
          * the code cannot be parsed,
          * the first top-level statement is not a (sync or async) function
            definition, or
          * the function has no docstring.

        When a docstring is present, the code is re-generated from the AST
        (so comments and original formatting are not preserved). If the
        docstring was the only statement in the function body, a ``pass``
        statement is substituted so the result is still valid Python.

        Returns:
            The function source without its docstring, or the original code
            as a fallback.
        """
        try:
            tree = ast.parse(self.code)
            if not tree.body or not isinstance(
                tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)
            ):
                # Not a function definition at the top level, return original code
                return self.code

            func_def_node = tree.body[0]
            body = func_def_node.body

            # A docstring is a bare string-constant expression in first position.
            has_docstring = (
                bool(body)
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)
            )
            if not has_docstring:
                # Nothing to strip: keep the code exactly as provided.
                return self.code

            # Drop the docstring; an empty suite would unparse to invalid
            # source ("def f():" with no body), so fall back to `pass`.
            func_def_node.body = body[1:] or [ast.Pass()]

            return ast.unparse(tree)
        except SyntaxError:
            # If parsing fails, return the original code as a fallback
            logger.warning(f"Could not parse code for abstraction '{self.name}' to remove docstring. Returning original code.")
            return self.code
        except Exception as e:
            # Deliberate best-effort: prompt formatting must not crash on odd input.
            logger.error(f"Unexpected error removing docstring for abstraction '{self.name}': {e}. Returning original code.")
            return self.code

class AbstractionLibrary:
    """Manages a collection of learned Abstraction objects for a single island.

    Abstractions are stored by name; names are unique within a library.
    """

    def __init__(self, max_prompt_chars: int = 32768, prompt_buffer: int = 4096):  # Approx 8000 tokens
        """Create an empty library.

        Args:
            max_prompt_chars: Total character budget for the prompt.
            prompt_buffer: Characters reserved for the rest of the prompt and
                the expected output; subtracted from ``max_prompt_chars`` to
                obtain the budget for the abstractions section.
        """
        self._abstractions: Dict[str, Abstraction] = {}  # name -> Abstraction
        self._max_prompt_chars = max_prompt_chars
        self._prompt_buffer = prompt_buffer
        logger.info(f"AbstractionLibrary initialized with max_prompt_chars: {self._max_prompt_chars}")

    def add_abstraction(self, abstraction: Abstraction) -> bool:
        """Adds a new abstraction to the library.

        Args:
            abstraction: The abstraction to store.

        Returns:
            True if added; False if the object is not an Abstraction, has an
            empty name or code, or an abstraction with the same name exists.
        """
        if not isinstance(abstraction, Abstraction):
            logger.error(f"Attempted to add non-Abstraction object: {type(abstraction)}")
            return False
        if not abstraction.name or not abstraction.code:
            logger.error(f"Attempted to add abstraction with missing name or code: {abstraction}")
            return False

        if abstraction.name in self._abstractions:
            logger.warning(f"Abstraction '{abstraction.name}' already exists in this library. Skipping.")
            return False

        self._abstractions[abstraction.name] = abstraction
        logger.info(f"Added abstraction '{abstraction.name}' to library: {abstraction.description[:80]}...")
        # Log code separately to avoid f-string issues with newlines/quotes
        logger.debug(f"Full code for '{abstraction.name}':")
        logger.debug(abstraction.code)  # Log raw code on its own line
        return True

    def get_abstractions(self) -> List[Abstraction]:
        """Returns a new list of all Abstraction objects in the library."""
        return list(self._abstractions.values())

    def __len__(self) -> int:
        """Returns the number of abstractions in the library."""
        return len(self._abstractions)

    def format_for_sampler_prompt(self) -> str:
        """Generates a formatted string of available abstractions for the Sampler's system prompt.

        Includes full function definitions without docstrings, shortest first
        (ties broken by name). Abstractions are pruned once the running
        character count would exceed ``max_prompt_chars - prompt_buffer``
        (the buffer defaults to 4096 characters and is left for the rest of
        the prompt and the response).

        Returns:
            The header followed by the selected definitions, or a placeholder
            line when the library is empty or nothing fits in the budget.
        """
        if not self._abstractions:
            return "  (No abstraction functions available in this island yet.)"

        effective_max_chars = self._max_prompt_chars - self._prompt_buffer

        header = "**Available Abstraction Functions (Implementations):**\n"
        current_chars = len(header)

        # Strip docstrings once per abstraction: each call re-parses the code,
        # so the result is reused for both sorting and emission.
        stripped = [
            (abs_obj.get_code_without_docstring(), abs_obj.name)
            for abs_obj in self._abstractions.values()
        ]
        # Shorter ones first, then alphabetically by name for tie-breaking, so
        # that more, smaller abstractions are included if space is limited.
        stripped.sort(key=lambda item: (len(item[0]), item[1]))

        selected_abstractions_code = []
        for code_without_docstring, _name in stripped:
            # Definitions are concatenated as-is, separated by two newlines.
            code_block_to_add = code_without_docstring + "\n\n"
            cost = len(code_block_to_add)

            if current_chars + cost > effective_max_chars:
                # Sorted by length, so every remaining entry is at least as large.
                logger.info(f"Reached character limit ({effective_max_chars} for abstractions). Pruning remaining {len(stripped) - len(selected_abstractions_code)} abstractions for prompt.")
                break

            selected_abstractions_code.append(code_block_to_add)
            current_chars += cost

        if not selected_abstractions_code:
            # This might happen if even the header + buffer exceeds max_prompt_chars,
            # or if the very first abstraction is too large.
            logger.warning("No abstractions could be formatted within the character limit for the prompt.")
            return "  (No abstraction functions could be included due to token limits.)"

        final_prompt_str = header + "".join(selected_abstractions_code)
        logger.info(f"Formatted abstractions for prompt. Total chars: {current_chars} (limit for section: {effective_max_chars}). Abstractions included: {len(selected_abstractions_code)}/{len(self._abstractions)}")
        return final_prompt_str

    def generate_definitions_for_sandbox(self) -> str:
        """Generates a string containing Python code definitions for all abstractions.

        DEPRECATED: Use get_definitions_file_content instead.

        Returns:
            All abstraction code joined by blank lines (with a trailing blank
            line), or an empty string when the library is empty.
        """
        # Keep this for potential compatibility if needed elsewhere, but log warning
        logger.warning("generate_definitions_for_sandbox is deprecated. Use get_definitions_file_content.")
        if not self._abstractions:
            return ""
        # Simple concatenation for the old method
        return "\n\n".join(abstraction.code for abstraction in self._abstractions.values()) + "\n\n"

    def get_definitions_file_content(self, include_imports: bool = True) -> str:
        """Generates the full content for a Python file containing all abstraction definitions.

        Args:
            include_imports: Whether to include a standard set of imports typically
                             needed by abstractions (math, typing, numpy, KnowledgeGraph).

        Returns:
            A string representing the content of the Python file, with
            abstractions ordered by name; an empty string when the library
            is empty.
        """
        # If there are no abstractions, return an empty string immediately.
        if not self._abstractions:
            return ""

        content_parts = []

        # 1. Add standard imports if requested (only if abstractions exist)
        if include_imports:
            # TODO(_; 4/16): Consider making these imports configurable or dynamically detected (interestingness prompt already has this)
            standard_imports = [
                "import math",
                "import numpy as np",
                "from typing import Dict, Any, List, Set, Optional, Union",
                "from frame.knowledge_base.knowledge_graph import KnowledgeGraph, NodeType",
                "# Assuming primitives might be needed by some complex abstractions",
                "from frame.interestingness.learning.dsl_primitives import * # Import all primitives"
            ]
            content_parts.extend(standard_imports)
            content_parts.append("\n")  # Add space after imports

        # 2. Add the code for each abstraction (we know abstractions exist here)
        sorted_abstractions = sorted(self._abstractions.values(), key=lambda x: x.name)
        for abstraction in sorted_abstractions:
            content_parts.append(abstraction.code)
            content_parts.append("\n")  # Add space between functions

        return "\n".join(content_parts)