"""Interestingness function generation using LLMs.

This module implements the LLM-based approach for generating
interestingness functions using a one-shot learning approach.
"""

import os
import re
import ast
import yaml
import time
import asyncio
import inspect
import logging
import traceback
from typing import Dict, List, Any, Optional, Callable, Union, Tuple
from datetime import datetime
import secrets

from frame.knowledge_base.knowledge_graph import KnowledgeGraph
from frame.tools.llm_caller import LLMCaller
from frame.interestingness.learning.base import BaseInterestingnessGenerator
from frame.interestingness.learning.dsl_primitives import ALL_PRIMITIVES


class OneShotLLMGenerator(BaseInterestingnessGenerator):
    """
    Generate interestingness functions using Large Language Models.

    This class handles:
    1. Prompt creation with available primitives
    2. LLM interaction to generate Python code for interestingness functions
    3. Validation and loading of generated functions
    """

    # Name the generated code is asked to define. It is looked up first when
    # locating the entry point, before falling back to a name heuristic.
    TARGET_FUNCTION_NAME = "calculate_interestingness"

    # Imports made available to every generated function before its own code
    # runs, so code that forgets to import these names still works.
    _EXEC_PREAMBLE = "\n".join([
        "import math",
        "import numpy as np",
        "from frame.knowledge_base.knowledge_graph import KnowledgeGraph",
        "from typing import Any, List, Dict, Optional, Tuple, Union",
    ])

    # Fenced code blocks, optionally tagged python / python3 / py (any case),
    # so the language tag never leaks into the captured code.
    _CODE_BLOCK_PATTERN = re.compile(r"```(?:python3?|py)?\s*([\s\S]*?)```", re.IGNORECASE)
    # Fallback when the reply has no fences: a bare calculate_interestingness
    # definition up to its first `return`/`pass` keyword.
    _BARE_FUNCTION_PATTERN = re.compile(r"def calculate_interestingness\(.*?\)[\s\S]*?(?:return|pass)")

    def __init__(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        prompt_template: Optional[Union[str, Dict[str, Any]]] = None,
        model_config_path: str = "frame/configs/models/gpt4o-mini.yaml",
        prompt_template_path: Optional[str] = None,
        output_dir: str = "frame/interestingness/learning/generated_programs",
        logger: Optional[logging.Logger] = None
    ):
        """
        Initialize the LLM interestingness function generator.

        Args:
            model: LLM model to use (overrides the model config's ``name``)
            temperature: Temperature for generation (overrides the model config)
            prompt_template: Either a raw prompt string (used verbatim) or a dict
                holding an already-extracted ``one_shot_prompt`` section.
            model_config_path: Path to the model configuration YAML
            prompt_template_path: Path to the prompt template YAML; used only
                when ``prompt_template`` is not given.
            output_dir: Directory to save generated functions
            logger: Optional logger instance

        Raises:
            ValueError: If the prompt template YAML does not contain a mapping.
        """
        super().__init__(config_path=model_config_path, output_dir=output_dir, logger=logger)

        self.model = model
        self.temperature = temperature
        self.prompt_template_path = prompt_template_path

        # Load the model configuration and apply explicit overrides.
        self.model_config = self._load_model_config(model_config_path)
        if model:
            self._config_set('name', model)
        if temperature is not None:
            self._config_set('temperature', temperature)

        # Defaults shared by every branch below, so these attributes always exist.
        self.prompt_template: Optional[Union[str, Dict[str, Any]]] = None
        self._required_imports = ""
        self._primitive_categories: Dict[str, Any] = {}

        if prompt_template is None and prompt_template_path:
            full_config = self._load_prompt_template(prompt_template_path)
            if not isinstance(full_config, dict):
                raise ValueError(
                    f"Prompt template YAML at {prompt_template_path} must contain a mapping."
                )
            # Keep only the 'one_shot_prompt' section when present; otherwise keep
            # the whole structure for compatibility with flat templates.
            self.prompt_template = full_config.get('one_shot_prompt', full_config)
            # Imports and primitives live at the top level of the YAML.
            self._required_imports = self._normalize_imports(full_config.get('required_imports'))
            self._primitive_categories = full_config.get('primitive_categories') or {}
        elif isinstance(prompt_template, str):
            # A raw string template carries no imports/primitives metadata.
            self.logger.warning("Raw string prompt template provided. Cannot automatically load imports or primitives from YAML.")
            self.prompt_template = prompt_template
        elif prompt_template is not None:
            # Assumed to be the 'one_shot_prompt' section already (dict-like).
            self.prompt_template = prompt_template
            self._required_imports = self._normalize_imports(prompt_template.get('required_imports'))
            self._primitive_categories = prompt_template.get('primitive_categories') or {}
            if not self._required_imports.strip() or not self._primitive_categories:
                self.logger.warning("Prompt template dict provided, but missing 'required_imports' or 'primitive_categories'. Imports/primitives might be missing.")
        else:
            self.logger.warning("No prompt template provided; generate() will fail until one is set.")

        # Initialize LLM caller
        self.llm_caller = LLMCaller(model_config=self.model_config, logger=self.logger)

        # Source of the most recently generated function (set by generate()).
        self.generated_function_code: Optional[str] = None

    @staticmethod
    def _normalize_imports(raw: Any) -> str:
        """
        Normalize a ``required_imports`` value into source text.

        ``None`` becomes a lone newline; a list/tuple of lines is joined with
        newlines; the result is stripped and given a single trailing newline.
        """
        if raw is None:
            return "\n"
        if isinstance(raw, (list, tuple)):
            raw = "\n".join(str(line) for line in raw)
        return str(raw).strip() + "\n"

    @staticmethod
    def _new_function_id() -> str:
        """Return a fresh ``interestingness_<YYYYmmdd_HHMMSS>_<8 hex chars>`` id."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"interestingness_{timestamp}_{secrets.token_hex(4)}"

    def _config_get(self, key: str, default: Any = None) -> Any:
        """Read ``key`` from the model config, whether it is a dict or an attribute object."""
        config = self.model_config
        if isinstance(config, dict):
            return config.get(key, default)
        return getattr(config, key, default)

    def _config_set(self, key: str, value: Any) -> None:
        """Write ``key`` into the model config, whether it is a dict or an attribute object."""
        if isinstance(self.model_config, dict):
            self.model_config[key] = value
        else:
            setattr(self.model_config, key, value)

    def _load_model_config(self, config_path: Optional[str]) -> Any:
        """
        Load a model configuration from a YAML file.

        Falls back to a built-in default (gpt-4o, temperature 1.0, 4096 max
        tokens) when no path is given, loading fails, or the file is empty. The
        default is a plain dict, so dict-style lookups work on either path.

        Args:
            config_path: Path to the YAML config file

        Returns:
            Model configuration (a dict on the fallback path)
        """
        default_config = {
            "name": "gpt-4o",
            "temperature": 1.0,
            "max_tokens": 4096,
            "system_prompt": "You are a mathematical assistant that specializes in creating interestingness functions.",
        }
        try:
            if not config_path:
                raise ValueError("no model config path given")
            resolved_path = self._resolve_path(config_path)
            self.logger.info(f"Loading model config from: {resolved_path}")

            if not os.path.exists(resolved_path):
                self.logger.error(f"Model config file not found at {resolved_path}")
                raise FileNotFoundError(f"Model config file not found at {resolved_path}")

            config = self._load_yaml_config(resolved_path)
        except Exception as e:
            self.logger.error(f"Error loading model config: {e}")
            return default_config

        if config is None:
            self.logger.warning(f"Model config at {resolved_path} is empty; using defaults.")
            return default_config
        return config

    def _load_prompt_template(self, template_path: str) -> Dict[str, Any]:
        """
        Load prompt template from a YAML file.

        Args:
            template_path: Path to the YAML template file

        Returns:
            Loaded template (expected to be a dictionary)

        Raises:
            Exception: Whatever path resolution or YAML loading raised; it is
                logged first.
        """
        try:
            resolved_path = self._resolve_path(template_path)
            self.logger.info(f"Loading prompt template from: {resolved_path}")
            return self._load_yaml_config(resolved_path)
        except Exception as e:
            self.logger.error(f"Error loading prompt template: {e}")
            raise

    def format_primitives(self, primitive_data: Optional[dict]) -> str:
        """
        Format DSL primitives from the loaded dictionary for inclusion in the prompt.

        Each category becomes a ``### <category>`` heading, an optional
        description line, one ``- name(entity_id, graph) -> type: desc`` line
        per primitive, and a trailing blank line.

        Args:
            primitive_data: Dictionary containing primitive categories and
                details. ``None`` or empty yields an empty string.

        Returns:
            Formatted string for inclusion in the prompt.
        """
        if not primitive_data:
            return ""

        primitives_text: List[str] = []
        for category, data in primitive_data.items():
            # A category declared with no body loads as None from YAML.
            data = data or {}
            primitives_text.append(f"### {category}")
            if 'description' in data:
                description = data['description']
                primitives_text.append("" if description is None else str(description))

            # An empty `primitives:` key also loads as None.
            for primitive in data.get('primitives') or []:
                name = primitive.get('name', 'N/A')
                desc = primitive.get('description', 'No description available.')
                return_type = primitive.get('return_type', 'Any')
                primitives_text.append(f"- {name}(entity_id, graph) -> {return_type}: {desc}")

            primitives_text.append("")

        return "\n".join(primitives_text)

    def generate_prompt(self,
                        template_data: Optional[Dict[str, Any]] = None,
                        primitive_data: Optional[Dict[str, Any]] = None,
                        additional_context: Optional[str] = None) -> str:
        """
        Generate a prompt for the LLM to create an interestingness function.
        Uses the 'one_shot_prompt' structure from the loaded YAML data.

        Sections are emitted in order: title, introduction, context,
        primitives_intro, formatted primitives, optional additional context,
        task, template, closing. Missing or null sections render as empty lines.

        Args:
            template_data: The dictionary containing the 'one_shot_prompt' structure.
                           If None, uses self.prompt_template.
            primitive_data: Dictionary containing primitive categories and details.
                           If None, uses self._primitive_categories.
            additional_context: Optional additional context to add to the prompt

        Returns:
            Formatted prompt string

        Raises:
            ValueError: If the template data is not a dictionary.
        """
        if template_data is None:
            template_data = getattr(self, 'prompt_template', None)
        if primitive_data is None:
            primitive_data = getattr(self, '_primitive_categories', None)

        if not isinstance(template_data, dict):
            self.logger.error("Prompt template data is not a dictionary. Cannot generate prompt.")
            raise ValueError("Invalid prompt template data provided.")

        def section(key: str) -> str:
            # Missing keys and explicit YAML nulls both become empty strings.
            value = template_data.get(key)
            return "" if value is None else str(value)

        sections = [
            section('title'),
            '',
            section('introduction'),
            '',
            section('context'),
            '',
            section('primitives_intro'),
            '',
            self.format_primitives(primitive_data),
        ]

        if additional_context:
            sections.extend(['', "## Additional Context", additional_context])

        sections.extend([
            '',
            section('task'),
            '',
            section('template'),
            '',
            section('closing'),
        ])

        return '\n'.join(sections)

    def save_prompt(self, prompt: str, output_path: str) -> None:
        """
        Save the generated prompt to a file (UTF-8), creating parent directories.

        Args:
            prompt: The prompt text to save
            output_path: Path to save the prompt to
        """
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(prompt)

    def _build_prompt_string(self, additional_context: Optional[str]) -> str:
        """
        Turn ``self.prompt_template`` into the prompt text sent to the LLM.

        Dict templates go through ``generate_prompt``. String templates are used
        verbatim, with ``additional_context`` appended under the same
        "## Additional Context" heading that ``generate_prompt`` uses.

        Raises:
            ValueError: If no prompt template is set, or it is neither a string
                nor a dict.
        """
        template = getattr(self, 'prompt_template', None)
        if template is None:
            raise ValueError(
                "No prompt template available: pass prompt_template or "
                "prompt_template_path when constructing OneShotLLMGenerator."
            )
        if isinstance(template, str):
            if additional_context:
                return f"{template}\n\n## Additional Context\n{additional_context}"
            return template
        return self.generate_prompt(
            template_data=template,
            primitive_data=getattr(self, '_primitive_categories', {}),
            additional_context=additional_context,
        )

    @staticmethod
    def _extract_response_content(response: Any) -> str:
        """
        Pull the reply text out of an LLM response.

        Handles plain strings, OpenAI-style dicts (``choices[0]['message']['content']``
        or ``choices[0]['text']``), dicts with a top-level ``content`` key, and
        OpenAI-style response objects. A ``None`` content becomes ``""``.

        Raises:
            ValueError: If a non-dict, non-string response has neither
                ``choices[0].message.content`` nor ``choices[0].text``.
        """
        if isinstance(response, str):
            return response

        if isinstance(response, dict):
            content = ""
            choices = response.get('choices')
            if choices:
                first = choices[0]
                if 'message' in first:
                    content = first['message']['content']
                elif 'text' in first:
                    content = first['text']
            elif 'content' in response:
                # Hugging Face format
                content = response['content']
            return content or ""

        try:
            content = response.choices[0].message.content
        except (AttributeError, IndexError, TypeError):
            try:
                content = response.choices[0].text
            except (AttributeError, IndexError, TypeError):
                raise ValueError("Could not extract content from LLM response")
        return content or ""

    async def generate(self, conversation_id: Optional[str] = None, **kwargs) -> Tuple[Callable, Optional[str]]:
        """
        Generate an interestingness function using the LLM.

        Args:
            conversation_id: Unique identifier for the conversation; also used as
                the saved file's identifier. Auto-generated when omitted.
            **kwargs: ``additional_context`` (str, appended to the prompt) and
                ``save_function`` (bool, default True).

        Returns:
            Tuple of (interestingness_function, path_if_saved)

        Raises:
            ValueError: If no prompt template is available, the model name is
                missing, or the response contains no usable code. Any error is
                logged with its traceback before being re-raised.
        """
        if conversation_id is None:
            conversation_id = self._new_function_id()

        additional_context = kwargs.get('additional_context', None)
        save_function = kwargs.get('save_function', True)

        try:
            prompt_str = self._build_prompt_string(additional_context)

            self.logger.debug(f"OneShotLLMGenerator: Using model config: {self.model_config}")
            model_name_to_use = self._config_get('name')
            self.logger.debug(f"OneShotLLMGenerator: Extracted model name: {model_name_to_use}")
            if not model_name_to_use:
                self.logger.error("Model name is missing from model_config! Cannot call LLM.")
                raise ValueError("Model name is missing in the loaded configuration.")

            self.logger.info(f"Calling LLM to generate interestingness function for {conversation_id}")
            response = await self.llm_caller.call_model(
                conversation_id=conversation_id,
                prompt=prompt_str,
                model=model_name_to_use,
                temperature=self._config_get('temperature', 0.7),
                max_tokens=self._config_get('max_tokens', None),
                system_prompt=self._config_get('system_prompt', None)
            )

            content = self._extract_response_content(response)
            function_content = self._extract_code_from_response(content)
            if not function_content:
                raise ValueError("No valid code found in the LLM response.")

            self.generated_function_code = function_content
            interestingness_function = self._create_function_from_code(function_content, conversation_id)

            function_path = self.save(function_content, conversation_id) if save_function else None
            return interestingness_function, function_path

        except Exception as e:
            self.logger.error(f"Error generating interestingness function: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def _extract_code_from_response(self, response_text: Optional[str]) -> str:
        """
        Extract Python code from the LLM response.

        All fenced code blocks (untagged or tagged python/python3/py) are joined
        with newlines. If there are none, falls back to a bare
        ``def calculate_interestingness(...)`` match, which stops at the first
        ``return``/``pass`` keyword.

        Args:
            response_text: The LLM response text (``None``/empty yields "")

        Returns:
            Extracted Python code as string, or "" when nothing was found
        """
        if not response_text:
            return ""
        matches = self._CODE_BLOCK_PATTERN.findall(response_text)
        if not matches:
            matches = self._BARE_FUNCTION_PATTERN.findall(response_text)
        return "\n".join(matches)

    @staticmethod
    def _primitive_bindings() -> Dict[str, Callable]:
        """
        Map DSL primitive names to callables taken from ``ALL_PRIMITIVES``.

        NOTE(review): the structure of ``ALL_PRIMITIVES`` is not visible in this
        module, so both a ``{name: callable}`` mapping and an iterable of named
        callables are accepted; anything else contributes nothing.
        """
        if isinstance(ALL_PRIMITIVES, dict):
            candidates = list(ALL_PRIMITIVES.items())
        else:
            try:
                candidates = [(getattr(p, '__name__', None), p) for p in ALL_PRIMITIVES]
            except TypeError:
                return {}
        return {name: obj for name, obj in candidates if isinstance(name, str) and callable(obj)}

    def _build_exec_namespace(self) -> Dict[str, Any]:
        """
        Build the globals dict that generated code is executed in.

        Contains the DSL primitives, then the standard preamble imports, then the
        template's ``required_imports`` (the same header ``save`` writes), so the
        in-memory function sees the names the saved file would. A failing
        ``required_imports`` snippet is logged and skipped.
        """
        namespace: Dict[str, Any] = dict(self._primitive_bindings())
        exec(self._EXEC_PREAMBLE, namespace)

        required_imports = getattr(self, '_required_imports', '') or ''
        if required_imports.strip():
            try:
                exec(required_imports, namespace)
            except Exception as e:
                self.logger.warning(f"Could not execute required imports from the prompt template: {e}")
        return namespace

    def _find_interestingness_function(
        self, namespace: Dict[str, Any], baseline: Dict[str, Any]
    ) -> Optional[Callable]:
        """
        Pick the entry point among the names the generated code itself bound.

        Prefers ``TARGET_FUNCTION_NAME``. Otherwise it returns the first callable
        whose name contains "interestingness", in definition order. Names that
        are still bound to the same object as in ``baseline`` (preamble imports,
        primitives) are ignored.
        """
        preferred = namespace.get(self.TARGET_FUNCTION_NAME)
        if callable(preferred) and baseline.get(self.TARGET_FUNCTION_NAME) is not preferred:
            return preferred
        for name, obj in namespace.items():
            if callable(obj) and "interestingness" in name.lower() and baseline.get(name) is not obj:
                return obj
        return None

    def _create_function_from_code(self, code: str, function_id: str) -> Callable:
        """
        Compile and execute generated code, returning its interestingness function.

        The code runs in a fresh namespace pre-populated by
        ``_build_exec_namespace``. The entry point is ``calculate_interestingness``
        if the code defines it, otherwise the first callable the code defines
        whose name contains "interestingness". It is then called once with
        dummy inputs as a smoke test; a failure there is only logged.

        Args:
            code: Python code string
            function_id: Identifier used in the pseudo-filename shown in tracebacks

        Returns:
            Callable function

        Raises:
            ValueError: If the code has a syntax error, fails to execute, or
                defines no interestingness function.
        """
        try:
            # Compile the code on its own (imports are injected into the
            # namespace instead of prepended), so traceback line numbers match
            # the generated source.
            compiled = compile(code, f"<generated:{function_id}>", "exec")

            namespace = self._build_exec_namespace()
            baseline = dict(namespace)

            # SECURITY NOTE: this executes LLM-generated code with full
            # interpreter privileges; only run it in a trusted/sandboxed setting.
            try:
                exec(compiled, namespace)
            except NameError as ne:
                self.logger.error(f"NameError during exec: {ne}. This might indicate a missing import not covered by the prefix.")
                raise
            except Exception as exec_err:
                self.logger.error(f"Error executing generated code: {exec_err}")
                raise

            interestingness_function = self._find_interestingness_function(namespace, baseline)
            if interestingness_function is None:
                raise ValueError("No interestingness function found in generated code")

            # Smoke test with minimal inputs; real inputs may still succeed, so
            # a failure here is only a warning.
            try:
                result = interestingness_function("dummy_id", KnowledgeGraph())
                self.logger.info(f"Test call succeeded, returned: {result}")
            except Exception as e:
                self.logger.warning(f"Test call to generated function failed: {e}")

            return interestingness_function

        except SyntaxError as e:
            self.logger.error(f"Syntax error in generated code: {e}")
            raise ValueError(f"Syntax error in generated interestingness function: {e}") from e
        except Exception as e:
            self.logger.error(f"Error creating function: {e}")
            raise ValueError(f"Error creating interestingness function: {e}") from e

    def save(self, artifact: str, output_path: Optional[str] = None) -> str:
        """
        Save the generated interestingness function to a file (UTF-8).

        Every file gets the same header, however ``output_path`` is given: a
        docstring with the function id and generation time, then the template's
        ``required_imports``, then the code.

        Args:
            artifact: The function code to save
            output_path: ``None`` to auto-generate an id; a str not ending in
                ``.py`` to use as the id (saved as
                ``<output_dir>/interestingness_<id>.py``); otherwise a full file
                path to write to.

        Returns:
            Path to the saved function file
        """
        timestamp = datetime.now().isoformat()

        if output_path is None:
            function_id = self._new_function_id()
            filepath = None
        elif isinstance(output_path, str) and not output_path.endswith('.py'):
            function_id = output_path
            filepath = None
        else:
            filepath = os.fspath(output_path)
            function_id = os.path.splitext(os.path.basename(filepath))[0]

        if filepath is None:
            if not function_id.startswith("interestingness_"):
                function_id = f"interestingness_{function_id}"
            filepath = os.path.join(self._resolve_path(self.output_dir), f"{function_id}.py")

        imports_section = getattr(self, '_required_imports', "# Imports not loaded\n")

        function_header = [
            '"""',
            f"Generated interestingness function: {function_id}",
            f"Generation time: {timestamp}",
            '"""',
            "",
            imports_section,
            "",
        ]
        full_function_text = '\n'.join(function_header) + '\n' + artifact

        os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(full_function_text)

        self.logger.info(f"Saved interestingness function to {filepath}")
        return filepath

    @classmethod
    def load(cls, artifact_path: str) -> Callable:
        """
        Load a previously generated interestingness function.

        Delegates to ``load_function_from_file`` (inherited; not defined in
        this class).

        Args:
            artifact_path: Path to the saved function file

        Returns:
            Loaded interestingness function
        """
        return cls.load_function_from_file(artifact_path)


# Utility functions that can be used outside the class

def _load_yaml_util(path):
    """Utility to load YAML with error handling."""
    try:
        with open(path, 'r') as f:
            return yaml.safe_load(f)
    except FileNotFoundError:
        logging.error(f"YAML file not found: {path}")
        raise
    except yaml.YAMLError as e:
        logging.error(f"Error parsing YAML file {path}: {e}")
        raise RuntimeError(f"Failed to parse YAML: {path}") from e
    except Exception as e:
         logging.error(f"Unexpected error loading YAML {path}: {e}")
         raise RuntimeError(f"Unexpected error loading YAML: {path}") from e

def prepare_prompt(
    primitive_categories_override: Optional[Dict[str, Any]] = None,
    template_path: Optional[str] = None,  # Required; None raises ValueError
    additional_context: Optional[str] = None
) -> str:
    """
    Prepare a prompt for generating an interestingness function (for one-shot use).
    Loads data from the central YAML configuration specified by template_path.

    Args:
        primitive_categories_override: Optional dictionary to override primitives from YAML.
        template_path: Path to the prompt template YAML.
        additional_context: Optional additional context to add to the prompt.

    Returns:
        Formatted prompt string.

    Raises:
        ValueError: If ``template_path`` is missing, or the YAML (or its
            ``one_shot_prompt`` section) is not a mapping.
        KeyError: If ``one_shot_prompt`` is absent, or ``primitive_categories``
            is absent and no override is given.
    """
    if template_path is None:
        raise ValueError("template_path must be provided to prepare_prompt.")

    full_config = _load_yaml_util(template_path)
    if not isinstance(full_config, dict):
        raise ValueError(f"Prompt template YAML at {template_path} must contain a mapping at the top level.")

    one_shot_template_data = full_config.get('one_shot_prompt')
    if one_shot_template_data is None:
        raise KeyError(f"Key 'one_shot_prompt' not found in {template_path}")
    if not isinstance(one_shot_template_data, dict):
        raise ValueError(f"Key 'one_shot_prompt' in {template_path} must be a mapping.")

    # Override wins when given; otherwise use the YAML's primitives.
    primitive_data = primitive_categories_override if primitive_categories_override is not None else full_config.get('primitive_categories')
    if primitive_data is None:
        raise KeyError(f"Key 'primitive_categories' not found in {template_path} and no override provided.")

    # Bypass __init__: only the prompt-formatting methods are needed, and full
    # construction would load a model config and create an LLMCaller.
    generator = OneShotLLMGenerator.__new__(OneShotLLMGenerator)
    generator.prompt_template = one_shot_template_data
    generator._primitive_categories = primitive_data

    return generator.generate_prompt(
        template_data=one_shot_template_data,
        primitive_data=primitive_data,
        additional_context=additional_context
    )

async def test_parser():
    """
    Test the code parser on a sample LLM response.

    Extracts code from a canned reply, tries to build a callable from it, and
    prints the outcome. No LLM call is made.
    """
    # Sample response with Python code blocks
    sample_response = """
Here's an interestingness function based on your specifications:

```python
def calculate_interestingness(entity_id: str, graph: KnowledgeGraph) -> float:
    \"\"\"
    Calculate the interestingness of an entity based on various factors.
    
    Args:
        entity_id: The ID of the entity to evaluate
        graph: The knowledge graph containing the entity
        
    Returns:
        A float between 0 and 1 representing the interestingness score
    \"\"\"
    # Get basic information about the entity
    node_type = get_entity_node_type(entity_id, graph)
    
    # Different scoring strategies based on entity type
    if node_type == 1.0:  # Concept
        return score_concept(entity_id, graph)
    elif node_type == 2.0:  # Conjecture
        return score_conjecture(entity_id, graph)
    else:
        # Default scoring for other types
        return 0.5  # Neutral score

def score_concept(entity_id: str, graph: KnowledgeGraph) -> float:
    # Implementation here
    return 0.7

def score_conjecture(entity_id: str, graph: KnowledgeGraph) -> float:
    # Implementation here
    return 0.8
```

The function uses different scoring strategies based on entity type and combines multiple factors.
"""

    # Build a bare instance: parsing and function creation only need a logger,
    # while __init__ would load a model config and create an LLMCaller.
    generator = OneShotLLMGenerator.__new__(OneShotLLMGenerator)
    generator.logger = logging.getLogger(__name__)

    code = generator._extract_code_from_response(sample_response)
    print(f"Extracted code:\n{code}")

    try:
        generator._create_function_from_code(code, "test_function")
        print("Successfully created function")
    except Exception as e:
        print(f"Error creating function: {e}")

async def main():
    """
    Demo run: generate one function with the LLM, call it, then reload it from disk.

    Errors raised after the generator is constructed are logged (with traceback)
    rather than propagated.
    """
    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("test")

    generator = OneShotLLMGenerator(model="gpt-4o", temperature=1.0, logger=demo_logger)

    try:
        generated_fn, saved_path = await generator.generate(save_function=True)
        print(f"Generated function saved to: {saved_path}")

        # Exercise the in-memory function on an empty graph.
        fresh_result = generated_fn("test_id", KnowledgeGraph())
        print(f"Test result: {fresh_result}")

        # Round-trip through the saved file.
        reloaded_fn = OneShotLLMGenerator.load(saved_path)
        reloaded_result = reloaded_fn("test_id", KnowledgeGraph())
        print(f"Loaded function result: {reloaded_result}")
    except Exception as exc:
        demo_logger.error(f"Error in test: {exc}")
        demo_logger.error(traceback.format_exc())


if __name__ == "__main__":
    asyncio.run(main())