# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class for sampling new programs."""
from collections.abc import Sequence
from typing import List, Optional, Tuple, Dict, Any

import numpy as np
import logging
import time
from typing import List
import yaml # Import YAML library
import os # Import os for path joining
import re # Import re for regular expressions

# Module-level logger, named after this module (`__name__`) so its output can
# be filtered or configured per module.
logger = logging.getLogger(__name__)

# Assume LLMCaller provides the necessary interface
from frame.tools.llm_caller import LLMCaller  

from frame.funsearch.implementation import evaluator
from frame.funsearch.implementation import programs_database
from frame.funsearch.implementation import code_manipulation
from frame.funsearch.implementation.abstraction_library import Abstraction
from omegaconf import OmegaConf, DictConfig
import ast
import textwrap
import asyncio

# Default path to the interestingness prompt YAML, resolved relative to this
# file's directory: <this dir>/../../configs/prompts/interestingness_prompt.yaml.
# (If this module lives in frame/funsearch/implementation, that is
# frame/configs/prompts/interestingness_prompt.yaml.)
# NOTE(review): MutationSampler receives its path explicitly via
# `prompt_yaml_path`; confirm whether this default is still used anywhere.
_PROMPT_YAML_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'configs', 'prompts', 'interestingness_prompt.yaml')

def _format_primitives(primitive_data: dict) -> str:
    """Formats the primitive data from YAML into a flat markdown list for the FunSearch prompt."""
    formatted_lines = [] # Initialize a list to hold the formatted lines
    for category, data in primitive_data.get('primitive_categories', {}).items():
        # Skip adding category headers and descriptions
        # primitives_string += f"\n**{category}:**\n"
        # if 'description' in data:
        #     primitives_string += f"{data['description']}\n"
        for primitive in data.get('primitives', []):
            name = primitive.get('name', 'N/A')
            full_desc = primitive.get('description', 'No description available.')
            
            # Try to extract arguments like `(args): ` from the description start
            args_match = re.match(r"^\s*(\(.*?\)):\s*(.*)", full_desc)
            if args_match:
                args = args_match.group(1).strip() # Extract e.g., (entity_id, graph)
                desc = args_match.group(2).strip() # Extract the rest of the description
                # Format with arguments inside backticks
                formatted_lines.append(f"- `{name}{args}`: {desc}")
            else:
                # Fallback if arguments format not found
                args = "" # No arguments to display
                desc = full_desc # Use the full description
                formatted_lines.append(f"- `{name}`: {desc}")
                
    return "\n".join(formatted_lines) # Join the list of lines


# --- End of prompt-formatting helpers --- #


class MutationSampler:
  """Node that samples program continuations (mutations) and sends them for analysis.

  Each call to `sample` fetches one prompt from the programs database, asks the
  LLM for `samples_per_prompt` completions in a single request, and dispatches
  every non-empty completion to a randomly chosen evaluator.
  """

  def __init__(
      self,
      database: programs_database.ProgramsDatabase,
      evaluators: Sequence[evaluator.Evaluator],
      llm_caller: LLMCaller,  # Required LLM caller
      samples_per_prompt: int,
      prompt_yaml_path: str,
      abstraction_enabled: bool,
  ) -> None:
    """Stores collaborators and loads/validates the prompt YAML.

    Args:
      database: Programs database; `sample` calls its `get_prompt()`.
      evaluators: Evaluators that generated samples are dispatched to.
      llm_caller: Shared LLM client; `generate` awaits its `call_model(...)`.
      samples_per_prompt: Number of completions requested per prompt.
      prompt_yaml_path: YAML file providing `funsearch_system_prompt`,
        `primitive_categories` and, optionally,
        `funsearch_abstraction_prompt_addon`.
      abstraction_enabled: Whether the abstraction addon is injected into the
        system prompt's `{abstraction_section}` placeholder.

    Raises:
      FileNotFoundError: If `prompt_yaml_path` does not exist.
      KeyError: If `primitive_categories` or `funsearch_system_prompt` is missing.
      RuntimeError: If the YAML cannot be parsed, or on any other loading error.
    """
    self._database = database
    self._evaluators = evaluators
    self._llm_caller = llm_caller
    self._samples_per_prompt = samples_per_prompt
    self._prompt_yaml_path = prompt_yaml_path
    self._abstraction_enabled = abstraction_enabled

    try:
      with open(self._prompt_yaml_path, 'r', encoding='utf-8') as f:
        self._prompt_config = yaml.safe_load(f)
      logger.info(f"Successfully loaded prompt config from {self._prompt_yaml_path}")

      # Validate the required section before formatting it.
      if 'primitive_categories' not in self._prompt_config:
        raise KeyError(f"Key 'primitive_categories' not found in {self._prompt_yaml_path}")

      self._formatted_primitives = _format_primitives(self._prompt_config)
      self._base_system_prompt_template = self._prompt_config['funsearch_system_prompt']
      # Optional addon; None when the key is absent.
      self._addon_prompt_template = self._prompt_config.get('funsearch_abstraction_prompt_addon')

      # Placeholders are filled by `_build_system_prompt`. Missing ones are
      # logged rather than treated as fatal.
      if '{primitives_section}' not in self._base_system_prompt_template:
        logger.error(f"Base system prompt in {self._prompt_yaml_path} is missing mandatory '{{primitives_section}}' placeholder!")
      if '{abstraction_section}' not in self._base_system_prompt_template:
        logger.error(f"Base system prompt in {self._prompt_yaml_path} is missing mandatory '{{abstraction_section}}' placeholder for addon approach!")

      if self._abstraction_enabled and not self._addon_prompt_template:
        logger.warning(f"Abstraction is enabled but 'funsearch_abstraction_prompt_addon' key is missing in {self._prompt_yaml_path}. Abstraction section will be empty.")
      elif self._abstraction_enabled and '{abstraction_list}' not in self._addon_prompt_template:
        # str.format ignores unused keyword arguments, so a missing placeholder
        # does not fail; the abstraction descriptions are just left out.
        logger.warning(f"Abstraction is enabled but '{{abstraction_list}}' placeholder is missing in 'funsearch_abstraction_prompt_addon' in {self._prompt_yaml_path}. Abstraction descriptions will not be included.")

    except FileNotFoundError:
      logger.error(f"MutationSampler: Prompt config file not found at {self._prompt_yaml_path}")
      raise
    except yaml.YAMLError as e:
      logger.error(f"MutationSampler: Error parsing prompt config YAML: {e}")
      raise RuntimeError(f"Failed to parse YAML config: {self._prompt_yaml_path}") from e
    except KeyError as e:
      logger.error(f"MutationSampler: Missing required key ({e}) in prompt config YAML: {self._prompt_yaml_path}")
      raise
    except Exception as e:
      logger.error(f"MutationSampler: Unexpected error loading prompt config: {e}")
      raise RuntimeError(f"Unexpected error loading prompt config from {self._prompt_yaml_path}") from e

    logger.info(f"MutationSampler initialized with {len(evaluators)} evaluators and {samples_per_prompt} samples per prompt.")

  def _build_system_prompt(self, abstraction_descriptions: Optional[str]) -> str:
    """Fills `{primitives_section}` and `{abstraction_section}` in the base system prompt.

    The abstraction section is the addon template formatted with
    `{abstraction_list}` when abstractions are enabled and an addon exists;
    otherwise it is "". Formatting errors are logged and replaced by fallback
    text instead of being raised.
    """
    abstraction_section_content = ""
    if self._abstraction_enabled and self._addon_prompt_template:
      descriptions_to_format = abstraction_descriptions if abstraction_descriptions else "  (No abstraction functions available in this island yet.)"
      try:
        abstraction_section_content = self._addon_prompt_template.format(abstraction_list=descriptions_to_format)
      except (KeyError, IndexError, ValueError) as e:
        # Raised for placeholders other than {abstraction_list} or unbalanced braces.
        logger.error(f"MutationSampler: Could not format abstraction addon template ({e!r}). Check it only uses the '{{abstraction_list}}' placeholder and escapes literal braces.")
        abstraction_section_content = "(Error: Could not format abstraction addon)"

    try:
      return self._base_system_prompt_template.format(
          primitives_section=self._formatted_primitives,
          abstraction_section=abstraction_section_content,
      )
    except (KeyError, IndexError, ValueError) as e:
      # Both supported placeholders were supplied, so retrying with the same
      # keys cannot succeed; go straight to the fallback text.
      logger.error(f"MutationSampler: System prompt template formatting error ({e!r}). Check YAML placeholders.")
      logger.error("Using basic fallback system prompt due to formatting error.")
      return "Error: Could not format system prompt template. Check YAML file."

  @staticmethod
  def _extract_samples(response: Any, num_samples: int) -> List[str]:
    """Returns the message content of every choice in `response`.

    Choices whose content is missing, None or empty become "". If fewer than
    `num_samples` choices are returned, the list is padded with "" so callers
    keep a stable count without losing the samples that did arrive. A response
    without choices yields `[""] * num_samples`.
    """
    choices = getattr(response, 'choices', None)
    if not choices:
      logger.error("LLM response had no choices or choices attribute. Returning empty samples.")
      logger.debug(f"Raw response object: {response}")
      return [""] * num_samples

    samples: List[str] = []
    for choice in choices:
      try:
        content = choice.message.content
      except (AttributeError, TypeError):
        logger.error("Error accessing content in LLM choice. Appending empty sample.")
        logger.debug(f"Problematic choice object: {choice}")
        content = None
      samples.append(content if content else "")

    logger.info(f"Received {len(samples)} samples from single LLM call.")
    if len(samples) < num_samples:
      logger.error(f"LLM returned fewer samples ({len(samples)}) than requested ({num_samples}). Padding with empty strings.")
      samples.extend([""] * (num_samples - len(samples)))
    return samples

  async def _request_samples(self, prompt: str, num_samples: int, abstraction_descriptions: Optional[str]) -> List[str]:
    """Makes a single LLM call with `n=num_samples` and extracts the completions.

    Errors raised by the LLM call or during extraction are logged and turned
    into `[""] * num_samples`.
    """
    # Millisecond resolution keeps IDs of calls issued within the same second distinct.
    conversation_id = f"funsearch_{int(time.time() * 1000)}"
    final_system_prompt = self._build_system_prompt(abstraction_descriptions)
    logger.info(f"MutationSampler: Final system prompt:\n---\n{final_system_prompt}\n---")

    try:
      response = await self._llm_caller.call_model(
          conversation_id=conversation_id,
          system_prompt=final_system_prompt,
          prompt=prompt,  # User prompt with code examples
          n=num_samples,  # Request all samples in one call
      )
      return self._extract_samples(response, num_samples)
    except Exception as e:
      logger.error(f"Error calling LLM: {e}", exc_info=True)
      return [""] * num_samples

  def generate(self, prompt: str, num_samples: int, abstraction_descriptions: Optional[str] = None) -> List[str]:
    """Generate samples using the LLM caller.

    Uses `asyncio.run`, so it must be called from synchronous code. Inside an
    already running event loop, `asyncio.run` raises; the error is logged and
    empty samples are returned.

    Args:
      prompt: The prompt to send to the LLM.
      num_samples: Number of samples to generate in a single LLM call.
      abstraction_descriptions: Optional formatted string describing available abstractions.

    Returns:
      The generated code samples. There are at least `num_samples` entries;
      missing or failed ones are "".
    """
    logger.info(f"MutationSampler: Calling LLM to generate {num_samples} samples for prompt...")
    try:
      start_time = time.time()
      samples = asyncio.run(self._request_samples(prompt, num_samples, abstraction_descriptions))
      end_time = time.time()
      logger.info(f"MutationSampler: LLM call completed in {end_time - start_time:.2f} seconds.")

      logger.info(f"MutationSampler: Received {len(samples)} raw samples from LLM:")
      for i, sample_content in enumerate(samples):
        logger.info(f"  Sample {i+1}:\n---\n{sample_content}\n---")
      logger.info(f"MutationSampler: Received {len(samples)} samples from LLM. Dispatching to evaluators...")
      return samples
    except Exception as e:
      logger.error(f"MutationSampler: Error during asyncio.run(_call_llm()) in generate: {e}", exc_info=True)
      logger.error("Returning empty samples due to error.")
      return [""] * num_samples

  def sample(self, iteration: int) -> None:
    """Gets a prompt, samples programs for it, and sends them for analysis.

    Args:
      iteration: Current iteration number, forwarded to each evaluator's `analyse`.
    """
    logger.info("MutationSampler: Getting prompt from database...")
    prompt = self._database.get_prompt()
    logger.info(f"MutationSampler: Calling LLM for prompt (island {prompt.island_id}, version {prompt.version_generated})...")
    logger.debug(f"MutationSampler: Prompt content:\n---\n{prompt.code[:500]}...\n---")

    samples = self.generate(
        prompt=prompt.code,
        num_samples=self._samples_per_prompt,
        abstraction_descriptions=prompt.abstraction_descriptions,
    )

    for sample in samples:
      if not sample.strip():
        logger.warning("Skipping empty sample")
        continue
      # Each sample goes to an independently chosen random evaluator, together
      # with the prompt's island/version and the iteration number.
      chosen_evaluator = np.random.choice(self._evaluators)
      chosen_evaluator.analyse(
          sample=sample,
          island_id=prompt.island_id,
          version_generated=prompt.version_generated,
          iteration=iteration,
      )

    logger.info("MutationSampler: Finished dispatching samples for this prompt.")

# === Abstraction Sampler ===

class AbstractionSampler:
    """Node that samples abstractions (reusable helper functions) from existing programs.

    Given an island's top-scoring programs, it prompts the LLM for new helper
    functions and parses the fenced code blocks in the reply into
    `Abstraction` objects.
    """

    def __init__(
        self,
        llm_caller: LLMCaller,  # Requires the shared LLM caller
        # Configuration dictionary containing prompts and settings
        abstraction_config: Dict[str, Any]
    ):
        """Reads prompts and settings from `abstraction_config`.

        Args:
            llm_caller: Shared LLM client. Its `call_model(...)` is awaited, and
                its optional `model_config` attribute is compared against the
                `llm` override section.
            abstraction_config: Mapping with `abstraction_system_prompt`,
                `abstraction_user_prompt_template`, an optional `llm` override
                section and optional `max_abstractions_per_step` (default 3).

        Raises:
            ValueError: If either prompt is missing or empty.
        """
        self._llm_caller = llm_caller
        self._config = abstraction_config
        self._system_prompt = self._config.get('abstraction_system_prompt')
        self._user_prompt_template = self._config.get('abstraction_user_prompt_template')
        self._llm_override_config_raw = self._config.get('llm')  # Raw override config from YAML
        self._max_abstractions = self._config.get('max_abstractions_per_step', 3)

        if not self._system_prompt or not self._user_prompt_template:
            logger.error("AbstractionSampler requires 'abstraction_system_prompt' and 'abstraction_user_prompt_template' in its configuration dict.")
            raise ValueError("AbstractionSampler missing required prompt configuration.")

        logger.info("AbstractionSampler initialized successfully.")

    @staticmethod
    def _format_program_example(score: float, func: code_manipulation.Function, island_id: int) -> str:
        """Renders one `(score, function)` pair as a commented Python function for the prompt."""
        logger.debug(f"[AbsSampler - Island {island_id}] Formatting program: Name='{func.name}', Args='{func.args}'")
        return_annotation = f' -> {func.return_type}' if func.return_type else ''
        # NOTE(review): the body is indented by 4 spaces here. If Function.body is
        # already indented (presumably so, as in upstream FunSearch's
        # code_manipulation), this double-indents it. Confirm against Function.
        return (
            f"# Score: {score:.4f}\n"
            f"def {func.name}({func.args if func.args else ''}){return_annotation}:\n"
            f"    \"\"\"{func.docstring if func.docstring else ''}\"\"\"\n"
            f"{textwrap.indent(func.body if func.body else 'pass', '    ')}"
        )

    async def generate_abstractions(
        self,
        top_programs: List[Tuple[float, code_manipulation.Function]],
        island_id: int,
        current_abstractions_str: str  # Descriptions of existing libs
    ) -> List[Abstraction]:
        """Generates and validates new abstractions based on provided top programs.

        Args:
            top_programs: `(score, function)` pairs shown to the LLM as examples.
            island_id: Island the abstractions are generated for; used in log
                messages and the conversation ID.
            current_abstractions_str: Description of the existing abstractions,
                inserted into the user prompt.

        Returns:
            The parsed abstractions. The list is empty when `top_programs` is
            empty, the user prompt cannot be formatted, or the LLM call fails.
        """
        if not top_programs:
            return []

        program_examples_str = "\n\n".join(
            self._format_program_example(score, func, island_id) for score, func in top_programs
        )

        try:
            user_prompt = self._user_prompt_template.format(
                max_abstractions=self._max_abstractions,
                current_abstractions=current_abstractions_str,
                program_examples=program_examples_str
            )
        except KeyError as e:
            logger.error(f"[AbsSampler - Island {island_id}] Missing key {e} in user prompt template. Cannot generate abstractions.")
            return []
        except Exception as e:
            logger.error(f"[AbsSampler - Island {island_id}] Error formatting user prompt: {e}. Cannot generate abstractions.")
            return []
        logger.debug(f"[AbsSampler - Island {island_id}] User Prompt (length {len(user_prompt)}):\n{user_prompt[:1000]}...")

        llm_response = await self._call_abstraction_llm(island_id, user_prompt)
        if not llm_response:
            logger.error(f"[AbsSampler - Island {island_id}] LLM call failed or returned no response.")
            return []

        logger.debug(f"[AbsSampler - Island {island_id}] Raw LLM response:\n{llm_response[:1000]}...")
        return self._parse_abstraction_llm_response(llm_response, island_id)

    def _resolve_llm_override(self, island_id: int) -> Optional[DictConfig]:
        """Builds the override config for the abstraction LLM call, if one is needed.

        Entries of the `llm` section whose value is None are dropped. An
        override is returned only when at least one remaining entry is missing
        from, or differs from, the caller's `model_config`; otherwise None.
        """
        main_llm_cfg = getattr(self._llm_caller, 'model_config', None)
        if main_llm_cfg is None:
            logger.warning("[AbsSampler] LLMCaller instance lacks 'model_config' attribute. Cannot check for overrides.")
            return None

        raw_override = self._llm_override_config_raw
        if not raw_override:
            return None
        if isinstance(raw_override, DictConfig):
            override_container = OmegaConf.to_container(raw_override, resolve=True)
        elif isinstance(raw_override, dict):
            # OmegaConf.to_container rejects non-OmegaConf input (ValueError),
            # so a plain dict is used directly.
            override_container = raw_override
        else:
            logger.warning(f"[AbsSampler] Invalid type for llm override config: {type(raw_override)}. Ignoring overrides.")
            return None

        override_dict = {k: v for k, v in override_container.items() if v is not None}
        if not override_dict:
            logger.debug(f"[AbsSampler - Island {island_id}] No specific LLM overrides configured.")
            return None

        differs_from_main = False
        for key, value in override_dict.items():
            main_value = OmegaConf.select(main_llm_cfg, key, default=None)
            if main_value is None or main_value != value:
                differs_from_main = True
                break
        if not differs_from_main:
            logger.debug(f"[AbsSampler - Island {island_id}] Override config same as main. No override applied.")
            return None

        override_config = OmegaConf.create(override_dict)
        logger.info(f"[AbsSampler - Island {island_id}] Using LLM override config: {OmegaConf.to_yaml(override_config)}")
        return override_config

    async def _call_abstraction_llm(self, island_id: int, prompt: str) -> Optional[str]:
        """Calls the LLM for abstraction generation, handling potential config overrides.

        Returns:
            The content of the first choice, or None when the call raises or
            the response has no usable content. Errors are logged, not raised.
        """
        conversation_id = f"abstraction_{island_id}_{int(time.time() * 1000)}"
        override_config = self._resolve_llm_override(island_id)
        use_override = override_config is not None

        try:
            logger.info(f"[AbsSampler - Island {island_id}] Calling LLM... (Override: {use_override})")
            logger.info(f"[AbsSampler - Island {island_id}] System Prompt:\n---\n{self._system_prompt}\n---")
            logger.info(f"[AbsSampler - Island {island_id}] User Prompt:\n---\n{prompt}\n---")

            start_time = time.time()
            response = await self._llm_caller.call_model(
                conversation_id=conversation_id,
                system_prompt=self._system_prompt,
                prompt=prompt,
                model_config_override=override_config
            )
            end_time = time.time()
            logger.info(f"[AbsSampler - Island {island_id}] LLM call completed in {end_time - start_time:.2f} seconds.")

            if not (hasattr(response, 'choices') and len(response.choices) > 0):
                logger.error(f"[AbsSampler - Island {island_id}] LLM call returned no choices or empty choices list.")
                logger.debug(f"[AbsSampler - Island {island_id}] Raw response object: {response}")
                return None

            try:
                content = response.choices[0].message.content
            except (AttributeError, IndexError, TypeError) as e:
                logger.error(f"[AbsSampler - Island {island_id}] LLM response structure unexpected or missing content: {e}")
                logger.debug(f"[AbsSampler - Island {island_id}] Raw response object: {response}")
                return None

            if not content:
                logger.error(f"[AbsSampler - Island {island_id}] LLM response content is empty.")
                logger.debug(f"[AbsSampler - Island {island_id}] Raw response object: {response}")
                return None

            logger.debug(f"[AbsSampler - Island {island_id}] LLM response content (first 1000 chars):\n{content[:1000]}...")
            return content
        except Exception as e:
            logger.error(f"[AbsSampler - Island {island_id}] Error calling LLM: {e}", exc_info=True)
            return None

    def _abstraction_from_code_block(self, code: str, island_id: int) -> Optional[Abstraction]:
        """Builds an Abstraction from the first function defined in `code`.

        Returns None (after logging a warning) when `code` defines no function.

        Raises:
            SyntaxError: If `code` is not valid Python.
        """
        tree = ast.parse(code)

        # Prefer top-level functions. Nested helpers inside an abstraction are
        # part of it and should not trigger the "multiple functions" warning.
        top_level_defs = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
        if top_level_defs:
            func_def = top_level_defs[0]
            if len(top_level_defs) > 1:
                logger.warning(f"[AbsSampler - Island {island_id}] Found multiple functions in one code block. Using the first ('{func_def.name}').")
        else:
            # Fall back to the first function `ast.walk` yields anywhere in the tree.
            func_def = next((node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)), None)

        if func_def is None:
            logger.warning(f"[AbsSampler - Island {island_id}] No function definition found in code block:\n{code[:200]}...")
            return None

        name = func_def.name
        args_str = ast.unparse(func_def.args)
        return_type_str = ast.unparse(func_def.returns) if func_def.returns else None
        docstring = ast.get_docstring(func_def, clean=False)
        description = docstring.strip() if docstring else "(No description provided)"
        signature_str = f"({args_str})" + (f" -> {return_type_str}" if return_type_str else "")

        # `code` already parsed successfully above, so it is stored verbatim
        # without a second syntax check.
        abstraction = Abstraction(
            name=name,
            signature=signature_str,
            description=description,
            code=code
        )
        logger.info(f"[AbsSampler - Island {island_id}] Successfully parsed abstraction: {name}{signature_str} from code block.")
        return abstraction

    def _parse_abstraction_llm_response(self, response_text: str, island_id: int) -> List[Abstraction]:
        """Parses the LLM response text to extract Abstraction objects from fenced code blocks.

        Every ```...``` block (optionally tagged `python`) is handled
        independently. Blocks that are empty, fail to parse, or define no
        function are logged and skipped.
        """
        logger.info(f"[AbsSampler - Island {island_id}] Raw LLM response received:\n{response_text}")
        abstractions: List[Abstraction] = []
        block_pattern = re.compile(r"```(?:python)?\s*([\s\S]*?)```", re.DOTALL)
        num_blocks = 0

        for block_match in block_pattern.finditer(response_text):
            num_blocks += 1
            code_inside_block = block_match.group(1).strip()
            if not code_inside_block:
                logger.debug(f"[AbsSampler - Island {island_id}] Skipping empty code block.")
                continue

            try:
                abstraction = self._abstraction_from_code_block(code_inside_block, island_id)
            except SyntaxError as e:
                logger.error(f"[AbsSampler - Island {island_id}] Syntax error parsing code block: {e}")
                logger.debug(f"[AbsSampler - Island {island_id}] Problematic code block:\n{code_inside_block[:500]}...")
                continue
            except Exception as e:
                # Any other failure while processing this block only skips the block.
                logger.error(f"[AbsSampler - Island {island_id}] Error processing code block for abstraction: {e}", exc_info=True)
                logger.debug(f"[AbsSampler - Island {island_id}] Problematic code block:\n{code_inside_block[:500]}...")
                continue

            if abstraction is not None:
                abstractions.append(abstraction)

        if not abstractions:
            logger.warning(f"[AbsSampler - Island {island_id}] Could not parse any valid abstractions from LLM response (checked {num_blocks} code blocks).")
            logger.debug(f"[AbsSampler - Island {island_id}] Raw LLM response start:\n{response_text[:1000]}...")

        return abstractions
