# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class for evaluating programs proposed by the Sampler."""
import ast
from collections.abc import Sequence
import copy
from typing import Any, Optional
import sys
import threading
import time
import traceback
import logging
import re
import textwrap

# Module-level logger, named after this module via `__name__`; the parsing and
# evaluation code below reports progress and failures through it.
logger = logging.getLogger(__name__)

from frame.funsearch.implementation import code_manipulation
from frame.funsearch.implementation import programs_database
from frame.funsearch.theory_builder_sandbox import TheoryBuilderSandbox


class _FunctionLineVisitor(ast.NodeVisitor):
  """Visitor that finds the start and end line numbers of a function body."""

  def __init__(self, target_function_name: str) -> None:
    self._target_function_name: str = target_function_name
    self._body_start_line: int | None = None
    self._body_end_line: int | None = None

  def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
    if node.name == self._target_function_name:
      if node.body:
        # Body starts at the first statement's line number
        self._body_start_line = node.body[0].lineno
        # Consider docstring: if the first node is a docstring, body starts after it
        if isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str):
             if len(node.body) > 1:
                  self._body_start_line = node.body[1].lineno
             else: # Only docstring exists
                  self._body_start_line = node.end_lineno + 1 # Indicate empty body

        # Body ends at the function's end line number
        self._body_end_line = node.end_lineno
      else: # Function has no body statements (e.g., just 'pass' implicitly or explicitly)
           # Use function's definition line to indicate no real body content
           self._body_start_line = node.lineno + 1 
           self._body_end_line = node.lineno 
    self.generic_visit(node) # Continue traversal

  @property
  def body_start_line(self) -> int | None:
    return self._body_start_line

  @property
  def body_end_line(self) -> int | None:
    return self._body_end_line

def _trim_function_body(generated_code: str) -> str:
  """Extracts the body of the generated function using code_manipulation.text_to_function.

  Assumes `generated_code` contains a single, complete function definition.
  Surrounding whitespace is stripped before parsing.

  Args:
    generated_code: The string containing the function definition.

  Returns:
    The extracted function body followed by one trailing newline.

  Raises:
    ValueError: If the input is empty/whitespace-only, if parsing fails for any
      reason (the underlying error is chained as `__cause__`), or if the parsed
      function has no body.
  """
  logger.debug("_trim_function_body called using code_manipulation")
  if not generated_code or not generated_code.strip():
    logger.error("Received empty code block in _trim_function_body.")
    raise ValueError("Input code block is empty or contains only whitespace.")

  code_to_parse = generated_code.strip()
  logger.debug(f"Attempting to parse code with text_to_function:\n{code_to_parse[:200]}...")

  # Only the parser call lives inside the `try`. The body check below raises its
  # own ValueError, which must not be caught here and re-reported as a parse
  # failure of text_to_function.
  try:
    parsed_function = code_manipulation.text_to_function(code_to_parse)
  except ValueError as ve:
    # Per the original note: text_to_function raises ValueError when it expects
    # exactly one function but finds zero or several.
    logger.error(f"code_manipulation.text_to_function failed: {ve}")
    raise ValueError(f"Failed to parse the input as a single function: {ve}") from ve
  except SyntaxError as se:
    # ast.parse within text_to_function might raise SyntaxError.
    logger.error(f"SyntaxError during parsing in text_to_function: {se}")
    raise ValueError(f"Invalid Python syntax in generated code: {se}") from se
  except Exception as e:
    # Any other parser failure is normalised to ValueError for the caller.
    logger.error(f"Unexpected error using text_to_function: {e}", exc_info=True)
    raise ValueError(f"Unexpected error parsing function code: {e}") from e

  body = parsed_function.body
  if body is None:
    logger.error("text_to_function succeeded but returned None for the body.")
    raise ValueError("Function body could not be extracted (returned None).")

  logger.debug(f"Successfully extracted body using text_to_function:\n{body[:200]}...")
  # A single trailing newline terminates the body.
  return body + '\n'

class Sandbox:
  """Interface for a sandbox that executes generated (untrusted) code.

  This base class provides no execution backend; `run` always raises.
  """

  def run(
      self,
      program: str,
      function_to_run: str,
      test_input: str,
      timeout_seconds: int,
  ) -> tuple[Any, bool]:
    """Returns `function_to_run(test_input)` and whether execution succeeded.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    message = 'Must provide a sandbox for executing untrusted code.'
    raise NotImplementedError(message)


def _calls_ancestor(program: str, function_to_evolve: str) -> bool:
  """Returns whether the generated function is calling an earlier version.

  Args:
    program: Source text whose called function names are inspected.
    function_to_evolve: Base name (without version suffix) of the evolved
      function.

  Returns:
    True if any called function name starts with `{function_to_evolve}_v`.
  """
  # In `program` the most recently generated function has already been renamed
  # to `function_to_evolve` (without the version suffix), so any remaining call
  # whose name starts with `function_to_evolve_v` targets an ancestor.
  ancestor_prefix = f'{function_to_evolve}_v'
  called_names = code_manipulation.get_functions_called(program)
  return any(called.startswith(ancestor_prefix) for called in called_names)


class Evaluator:
  """Class that analyses functions generated by LLMs.

  A sample is parsed into a program, scored by the sandbox and, when the run
  succeeds and the program does not call an ancestor version of itself,
  registered in the programs database.
  """

  def __init__(
      self,
      database: programs_database.ProgramsDatabase,
      template: code_manipulation.Program,
      function_to_evolve: str,
      function_to_run: str,
      inputs: Sequence[Any],
      timeout_seconds: int = 30,
      sandbox: Optional[TheoryBuilderSandbox] = None
  ):
    """Stores the evaluation configuration.

    Args:
      database: Database in which successful programs are registered.
      template: Program template into which each sampled body is inserted.
      function_to_evolve: Name of the function whose body is being evolved.
      function_to_run: Stored on the instance; not read by `analyse`.
      inputs: Stored on the instance; not read by `analyse`.
      timeout_seconds: Stored on the instance; not read by `analyse`.
      sandbox: Sandbox used to score programs. Required despite the `None`
        default, which only exists to keep the parameter keyword-friendly.

    Raises:
      ValueError: If `sandbox` is None.
    """
    self._database = database
    self._template = template
    self._function_to_evolve = function_to_evolve
    self._function_to_run = function_to_run
    self._inputs = inputs
    self._timeout_seconds = timeout_seconds
    # Ensure a sandbox is provided
    if sandbox is None:
        raise ValueError("Evaluator requires a TheoryBuilderSandbox instance.")
    self._sandbox: TheoryBuilderSandbox = sandbox

  def analyse(
      self,
      sample: str,
      island_id: int | None,
      version_generated: int | None,
      iteration: Optional[int] = None,
  ) -> None:
    """Compiles the sample and passes the function code string to the sandbox.

    Failures never propagate: parsing errors and sandbox/registration errors
    are logged and the method returns without registering anything.

    Args:
      sample: Raw LLM output containing the proposed function.
      island_id: Island the sample belongs to; forwarded to the sandbox and to
        the database.
      version_generated: Version suffix used in the prompt, forwarded to the
        parser so self-references can be renamed.
      iteration: Optional iteration counter forwarded to the sandbox.
    """
    logger.debug(f"analyse called for island {island_id}, version {version_generated}, iteration {iteration}.")

    try:
        # Parse the sample into the evolved function object and a string that
        # contains only the program's functions.
        logger.debug("Attempting to parse sample and extract function code...")
        new_function, evolved_functions_str = _sample_to_program_main(
            sample, version_generated, self._template, self._function_to_evolve)
        logger.debug("Successfully extracted function object and functions-only string.")
    except ValueError as parsing_error:
        logger.error(f"Failed to parse sample or extract function code: {parsing_error}")
        logger.debug(f"Original sample (first 500 chars):\n{sample[:500]}...")
        return
    except Exception as general_parsing_error:
        logger.error(f"Unexpected error during function extraction: {general_parsing_error}", exc_info=True)
        logger.debug(f"Original sample (first 500 chars):\n{sample[:500]}...")
        return

    try:
        # The sandbox receives only the functions string, island id and iteration.
        logger.debug(f"Running sandbox evaluation for program (island {island_id})...")
        reward, success = self._sandbox.run(
            main_function_code=evolved_functions_str,
            island_id=island_id,
            iteration=iteration
        )
        logger.debug(f"Sandbox evaluation completed. Success: {success}, Reward: {reward}")

        # One exhaustive decision chain: the ancestor check is evaluated at most
        # once, and only for runs that succeeded.
        if not success:
            logger.warning(f"Program evaluation failed or returned non-finite reward ({reward}). Skipping registration.")
        elif _calls_ancestor(evolved_functions_str, self._function_to_evolve):
            logger.warning("Program calls an ancestor function, skipping registration.")
        else:
            logger.info(f"Registering program with reward: {reward}")
            self._database.register_program(
                new_function, island_id, {'eval_score': reward}
            )

    except Exception as e:
        # Best-effort boundary: a failing sample must not abort the caller.
        logger.error(f"Exception during sandbox execution or program registration: {e}")
        logger.debug(traceback.format_exc())
    finally:
        logger.debug(f"analyse finished for island {island_id}, version {version_generated}, iteration {iteration}.")

# _sample_to_program_full = _sample_to_program

def _parses_as_python(code: str) -> bool:
    """Returns whether `code` is accepted by `ast.parse` exactly as written."""
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, RecursionError, MemoryError):
        return False
    return True


def _clean_generated_code(code: str) -> str:
    """Undoes escaped newlines/tabs in `code`, but only when that is necessary.

    Code that already parses is returned untouched: expanding escape sequences
    in valid source would rewrite string literals such as "\\n" and break them.
    Otherwise the text is assumed to carry literal escape sequences and is
    unescaped. The Latin-1/backslashreplace round trip keeps non-ASCII
    characters intact, whereas decoding UTF-8 bytes with `unicode_escape`
    would turn them into mojibake.
    """
    if _parses_as_python(code):
        return code
    try:
        unescaped = code.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except Exception:
        # Best effort: malformed escape sequences leave the text as it was.
        unescaped = code
    return unescaped.replace(r'\\n', '\n').replace(r'\\t', '\t')


def _sample_to_program_main(
    generated_code: str,
    version_generated: int | None,
    template: code_manipulation.Program,
    function_to_evolve: str,
) -> tuple[code_manipulation.Function, str]:
    """Creates a Function object and the string for only the evolved function(s).

    Steps: take the first fenced code block from the LLM response (or the whole
    response if there is none), clean it, extract the function body, rename
    calls to the versioned name back to `function_to_evolve`, and insert the
    body into a deep copy of `template`.

    Args:
      generated_code: Raw LLM response.
      version_generated: Version suffix used in the prompt; when not None,
        references to `{function_to_evolve}_v{version_generated}` in the body
        are renamed to `function_to_evolve`.
      template: Program template; it is deep-copied, never mutated.
      function_to_evolve: Name of the function whose body is replaced.

    Returns:
      A tuple of the evolved function object (belonging to the copied program)
      and the newline-joined source of all functions of that program.

    Raises:
      ValueError: If no code is found, the body cannot be extracted, renaming
        fails, or the program/functions string cannot be built.
    """
    logger.debug("_sample_to_program_main called.")

    # --- Code extraction and cleaning ---
    python_code_pattern = r"```(?:python)?\s*([\s\S]*?)```"
    fenced_block = re.search(python_code_pattern, generated_code)
    if fenced_block:
        # Only the first fenced block is considered.
        extracted_code = fenced_block.group(1).strip()
    else:
        logger.warning("No fenced code block found in sample, attempting fallback.")
        extracted_code = generated_code.strip()

    cleaned_code = _clean_generated_code(extracted_code)
    if not fenced_block and not cleaned_code.strip():
        # The response was just noise; there is nothing to parse.
        logger.error("LLM response contained no parseable code block.")
        raise ValueError("No code block found in LLM response.")

    logger.debug(f"Final cleaned code for main function body:\n{cleaned_code}")
    body = _trim_function_body(cleaned_code)
    logger.debug(f"Extracted main function body:\n{body}")

    # --- Rename self-references to the versioned name ---
    if version_generated is not None:
        try:
            # Result unused; presumably validates that the template contains
            # the function (any failure is reported as ValueError below).
            template.get_function(function_to_evolve)
            original_func_name = f'{function_to_evolve}_v{version_generated}'
            if original_func_name in body:
                body = code_manipulation.rename_function_calls(
                    body, original_func_name, function_to_evolve)
        except Exception as rename_e:
            raise ValueError(f"Failed rename: {rename_e}") from rename_e

    # --- Create Program object and extract the functions string (no preface) ---
    logger.debug("Creating Program object and extracting functions string...")
    program = copy.deepcopy(template)
    try:
        evolved_function = program.get_function(function_to_evolve)
        evolved_function.body = body

        # String made of ONLY the program's functions.
        functions_string = '\n'.join(str(f) for f in program.functions)

        logger.debug(f"Constructed functions-only string (len={len(functions_string)}):\n---\n{functions_string[:1000]}...\\n---")
        return evolved_function, functions_string
    except Exception as e:
        logger.error(f"Error creating program object/functions string: {e}", exc_info=True)
        raise ValueError(f"Failed to create program/functions string: {e}") from e
