# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A programs database that implements the evolutionary algorithm."""
from collections.abc import Mapping, Sequence
import copy
import dataclasses
import time
from typing import Any, Dict, List, Tuple, Optional

import datetime
from collections.abc import Collection, Iterator
import numpy as np
import scipy
import re
import logging
import ast

from frame.funsearch.implementation import code_manipulation
from frame.funsearch.implementation import config as config_lib
from frame.funsearch.implementation.abstraction_library import Abstraction, AbstractionLibrary

# Set up module-level logger
logger = logging.getLogger(__name__)

# A cluster key: the per-test scores of a program laid out as a tuple of
# floats (built by `_get_signature`, ordered by sorted test key).
Signature = tuple[float, ...]
# Maps a test identifier (any key type) to the score obtained on that test.
ScoresPerTest = Mapping[Any, float]


def _softmax(logits: np.ndarray, temperature: float) -> np.ndarray:
  """Computes a temperature-scaled softmax over finite `logits`.

  Args:
    logits: Array of finite values. Non-floating dtypes are cast to float32
      before the division by `temperature`.
    temperature: Divisor applied to the logits before the softmax.

  Returns:
    The softmax probabilities, with the largest entry recomputed as one minus
    the sum of all the others.

  Raises:
    ValueError: If any logit is NaN or infinite.
  """
  finite_mask = np.isfinite(logits)
  if not np.all(finite_mask):
    non_finites = set(logits[~finite_mask])
    raise ValueError(f'`logits` contains non-finite value(s): {non_finites}')
  if not np.issubdtype(logits.dtype, np.floating):
    logits = np.array(logits, dtype=np.float32)

  probabilities = scipy.special.softmax(logits / temperature, axis=-1)
  # `np.random.choice` rejects probability vectors that do not sum to 1, so
  # the rounding error is absorbed into the largest entry.
  peak = np.argmax(probabilities)
  mass_before_peak = np.sum(probabilities[:peak])
  mass_after_peak = np.sum(probabilities[peak + 1:])
  probabilities[peak] = 1 - mass_before_peak - mass_after_peak
  return probabilities


def _reduce_score(scores_per_test: ScoresPerTest) -> float:
  """Collapses per-test scores into one number: the score of the last test.

  "Last" is the final key in the mapping's own iteration order.
  """
  last_test = list(scores_per_test)[-1]
  return scores_per_test[last_test]


def _get_signature(scores_per_test: ScoresPerTest) -> Signature:
  """Builds the canonical signature: the scores ordered by sorted test key."""
  ordered_tests = sorted(scores_per_test)
  return tuple(scores_per_test[test] for test in ordered_tests)


@dataclasses.dataclass(frozen=True)
class Prompt:
  """An immutable prompt produced by the ProgramsDatabase for Samplers.

  Attributes:
    code: The prompt text, ending with the header of the function to be
      completed.
    version_generated: Version number reported by `Island.get_prompt` for the
      function the sampler is asked to complete.
    island_id: Identifier of the island that produced the implementations
      included in the prompt. Used to direct the newly generated implementation
      into the same island.
    abstraction_descriptions: Formatted descriptions of the island's
      abstractions, as returned by the library's `format_for_sampler_prompt`.
  """
  code: str
  version_generated: int
  island_id: int
  abstraction_descriptions: str


class ProgramsDatabase:
  """A collection of programs, organized as islands.

  Every island is an independent sub-population paired with its own
  abstraction library. Once `config.reset_period` seconds have elapsed since
  the last reset, registering a program triggers a reset of the weaker half of
  the islands, which are re-seeded from a surviving island.
  """

  def __init__(
      self,
      config: config_lib.ProgramsDatabaseConfig,
      template: code_manipulation.Program,
      function_to_evolve: str,
  ) -> None:
    """Creates `config.num_islands` empty islands and abstraction libraries.

    Args:
      config: Database configuration (number of islands, prompt size,
        sampling temperature schedule, reset period, abstraction prompt size).
      template: Program template handed to every island.
      function_to_evolve: Name of the function being evolved.
    """
    self._config: config_lib.ProgramsDatabaseConfig = config
    self._template: code_manipulation.Program = template
    self._function_to_evolve: str = function_to_evolve

    num_islands = config.num_islands
    self._islands: list[Island] = [
        self._new_island() for _ in range(num_islands)]
    self._best_score_per_island: list[float] = [-float('inf')] * num_islands
    self._best_program_per_island: list[code_manipulation.Function | None] = (
        [None] * num_islands)
    self._best_scores_per_test_per_island: list[ScoresPerTest | None] = (
        [None] * num_islands)

    # Abstraction libraries, one per island.
    self._abstraction_libraries: Dict[int, AbstractionLibrary] = {
        island_id: self._new_abstraction_library()
        for island_id in range(num_islands)}

    self._last_reset_time: float = time.time()

  def _new_island(self) -> 'Island':
    """Builds an empty island from the stored template and configuration."""
    return Island(
        self._template,
        self._function_to_evolve,
        self._config.functions_per_prompt,
        self._config.cluster_sampling_temperature_init,
        self._config.cluster_sampling_temperature_period)

  def _new_abstraction_library(self) -> AbstractionLibrary:
    """Builds an empty library honouring the configured prompt-size limit."""
    return AbstractionLibrary(
        max_prompt_chars=self._config.abstraction_max_prompt_chars)

  @staticmethod
  def _rename_first_function(code: str, new_name: str) -> str:
    """Returns `code` with the first function definition found renamed.

    `ast.walk` visits nodes breadth-first, so a top-level `def` is found
    before any nested one. Only the definition is renamed; calls are left
    untouched. If `code` has no function definition it is simply re-rendered.

    Raises:
      SyntaxError: If `code` cannot be parsed (other parser errors propagate
        as well).
    """
    tree = ast.parse(code)
    for node in ast.walk(tree):
      if isinstance(node, ast.FunctionDef):
        node.name = new_name
        break  # One function definition per abstraction is assumed.
    return ast.unparse(tree)

  def get_prompt(self) -> Prompt:
    """Returns a prompt containing implementations from one chosen island.

    The island is chosen uniformly at random.
    """
    island_id = np.random.randint(len(self._islands))
    island = self._islands[island_id]
    library = self._abstraction_libraries.get(island_id)
    if library is None:
      # Should not happen with proper init, but handle defensively. The
      # replacement uses the configured size limit like every other library.
      logger.error(
          "Critical error: Abstraction library missing for island %s during "
          "get_prompt. Creating empty.", island_id)
      library = self._new_abstraction_library()
      self._abstraction_libraries[island_id] = library

    # The island needs the library to embed abstraction descriptions.
    code, version_generated = island.get_prompt(library)

    # The same descriptions are also exposed separately on the Prompt.
    abstraction_descriptions = library.format_for_sampler_prompt()
    logger.debug(
        "Island %s Abstraction Lib Desc for prompt:\n%s",
        island_id, abstraction_descriptions)

    return Prompt(
        code=code,
        version_generated=version_generated,
        island_id=island_id,
        abstraction_descriptions=abstraction_descriptions,
    )

  def _register_program_in_island(
      self,
      program: code_manipulation.Function,
      island_id: int,
      scores_per_test: ScoresPerTest,
  ) -> None:
    """Registers `program` in the specified island.

    Also updates the island's best program, scores and reduced score when the
    new program's reduced score is strictly higher than the current best.
    """
    self._islands[island_id].register_program(program, scores_per_test)
    score = _reduce_score(scores_per_test)
    if score > self._best_score_per_island[island_id]:
      self._best_program_per_island[island_id] = program
      self._best_scores_per_test_per_island[island_id] = scores_per_test
      self._best_score_per_island[island_id] = score
      logger.info('Best score of island %d increased to %s', island_id, score)

  def register_program(
      self,
      program: code_manipulation.Function,
      island_id: int | None,
      scores_per_test: ScoresPerTest,
  ) -> None:
    """Registers `program` in the database.

    Args:
      program: The program to store.
      island_id: Target island, or None to add the program to every island
        (used for programs added at the beginning).
      scores_per_test: Scores of `program` on each test.
    """
    # In an asynchronous implementation we should consider the possibility of
    # registering a program on an island that had been reset after the prompt
    # was generated. Leaving that out here for simplicity.
    if island_id is None:
      # This is a program added at the beginning, so adding it to all islands.
      for target_island_id in range(len(self._islands)):
        self._register_program_in_island(
            program, target_island_id, scores_per_test)
    else:
      self._register_program_in_island(program, island_id, scores_per_test)

    # Check whether it is time to reset an island.
    if time.time() - self._last_reset_time > self._config.reset_period:
      self._last_reset_time = time.time()
      self.reset_islands()

  def get_top_programs_for_island(
      self, island_id: int, num_to_sample: int
  ) -> List[Tuple[float, code_manipulation.Function]]:
    """Retrieves up to `num_to_sample` top programs from an island.

    Returns:
      A list of `(score, program)` tuples as produced by
      `Island.get_top_programs`, or an empty list (after logging an error) if
      `island_id` is out of range.
    """
    if not (0 <= island_id < len(self._islands)):
      logger.error(
          "Invalid island_id %s requested for top programs.", island_id)
      return []

    # Let the island handle retrieving its best programs.
    top_programs = self._islands[island_id].get_top_programs(num_to_sample)
    logger.info(
        "Retrieved %s top programs from island %s for abstraction.",
        len(top_programs), island_id)
    return top_programs

  def update_abstraction_library(
      self, island_id: int, new_abstractions: List[Abstraction]) -> None:
    """Adds `new_abstractions` to the abstraction library of an island.

    An abstraction whose name already exists in the library is renamed to the
    first free `{name}_v{k}` (k = 1, 2, ...) and the function definition in its
    code is renamed to match. The abstraction object is only modified once the
    code rewrite has succeeded; if the rewrite fails the abstraction is left
    untouched and skipped.

    Args:
      island_id: Island whose library is updated. Out-of-range ids are logged
        and ignored.
      new_abstractions: Abstractions to add; renamed ones are mutated in place.
    """
    if not (0 <= island_id < self._config.num_islands):
      logger.error(
          "Invalid island_id %s for updating abstraction library.", island_id)
      return

    # Create a new library for the island if it doesn't exist.
    if island_id not in self._abstraction_libraries:
      logger.warning(
          "Abstraction library for island %s not found during update. "
          "Creating one now.", island_id)
      self._abstraction_libraries[island_id] = self._new_abstraction_library()

    library = self._abstraction_libraries[island_id]
    added_count = 0
    for abstraction in new_abstractions:
      original_name = abstraction.name
      new_name = original_name
      version_counter = 1
      # Find the first name not already used as a key in the library.
      while new_name in library._abstractions:
        new_name = f"{original_name}_v{version_counter}"
        version_counter += 1

      if new_name != original_name:
        logger.info(
            "[Island %s] Renaming abstraction '%s' to '%s' due to collision.",
            island_id, original_name, new_name)
        try:
          renamed_code = self._rename_first_function(
              abstraction.code, new_name)
        except Exception as parse_rename_e:
          # Deliberately best-effort: one unparsable abstraction must not
          # prevent the remaining ones from being added.
          logger.error(
              "[Island %s] Failed to parse/rename code for abstraction "
              "'%s' -> '%s': %s. Skipping addition.",
              island_id, original_name, new_name, parse_rename_e)
          continue
        # Apply name and code together so they can never disagree.
        abstraction.name = new_name
        abstraction.code = renamed_code
        logger.debug(
            "[Island %s] Updated code for renamed abstraction '%s'.",
            island_id, new_name)

      # Add the (potentially renamed) abstraction to the library object.
      if library.add_abstraction(abstraction):
        added_count += 1

    logger.info(
        "Added %s new/renamed abstractions to library for island %s.",
        added_count, island_id)

  def get_abstraction_library_content(
      self, island_id: int, include_imports: bool = True) -> Optional[str]:
    """Gets the abstraction library content as a string for an island.

    Args:
      island_id: The ID of the island.
      include_imports: Forwarded to the library's
        `get_definitions_file_content`.

    Returns:
      The definitions-file content of the island's library. If `island_id` is
      None or has no library, a warning is logged and the content of a fresh,
      empty library (built with the configured prompt-size limit) is returned
      instead; this method does not itself substitute None.
    """
    if island_id is None or island_id not in self._abstraction_libraries:
      logger.warning(
          "Cannot get abstraction library content: Island ID %s is invalid "
          "or library not found.", island_id)
      # A temporary empty library yields whatever default content (e.g.
      # imports) the library emits when it holds no abstractions.
      library = self._new_abstraction_library()
    else:
      library = self._abstraction_libraries[island_id]
    return library.get_definitions_file_content(
        include_imports=include_imports)

  def reset_islands(self) -> None:
    """Resets the weaker half of islands.

    Each reset island is replaced by an empty one, receives a deep copy of the
    abstraction library of a randomly chosen surviving island, and is seeded
    with that survivor's best program (when it has one).
    """
    # We sort best scores after adding minor noise to break ties.
    noise = np.random.randn(len(self._best_score_per_island)) * 1e-6
    indices_sorted_by_score: np.ndarray = np.argsort(
        np.asarray(self._best_score_per_island) + noise)
    num_islands_to_reset = self._config.num_islands // 2
    reset_islands_ids = indices_sorted_by_score[:num_islands_to_reset]
    keep_islands_ids = indices_sorted_by_score[num_islands_to_reset:]
    for raw_island_id in reset_islands_ids:
      # Plain ints keep the library dict keys and list indices homogeneous.
      island_id = int(raw_island_id)
      self._islands[island_id] = self._new_island()
      self._best_score_per_island[island_id] = -float('inf')
      self._best_program_per_island[island_id] = None
      self._best_scores_per_test_per_island[island_id] = None

      founder_island_id = int(np.random.choice(keep_islands_ids))
      founder = self._best_program_per_island[founder_island_id]
      founder_scores = self._best_scores_per_test_per_island[founder_island_id]

      # Copy founder's abstraction library.
      if founder_island_id in self._abstraction_libraries:
        self._abstraction_libraries[island_id] = copy.deepcopy(
            self._abstraction_libraries[founder_island_id])
        logger.info(
            "Island %s reset. Copied abstraction library (size %s) from "
            "founder island %s.",
            island_id, len(self._abstraction_libraries[island_id]),
            founder_island_id)
      else:
        # Should not happen if initialized correctly, but handle defensively.
        self._abstraction_libraries[island_id] = (
            self._new_abstraction_library())
        logger.warning(
            "Founder island %s had no abstraction library during reset of "
            "island %s. Initialized empty library.",
            founder_island_id, island_id)

      # Register founder program *after* lib copy.
      if founder is not None and founder_scores is not None:
        self._register_program_in_island(founder, island_id, founder_scores)
      else:
        logger.warning(
            "Founder island %s had no best program/scores to seed reset "
            "island %s.", founder_island_id, island_id)


class Island:
  """A sub-population of the programs database.

  Programs are grouped into clusters keyed by the signature of their per-test
  scores. A prompt is built by sampling clusters (a higher cluster score is
  more likely) and then one program from each sampled cluster.
  """

  def __init__(
      self,
      template: code_manipulation.Program,
      function_to_evolve: str,
      functions_per_prompt: int,
      cluster_sampling_temperature_init: float,
      cluster_sampling_temperature_period: int,
  ) -> None:
    """Creates an empty island.

    Args:
      template: Program template; its `preface` is prepended to prompts.
      function_to_evolve: Base name of the function being evolved.
      functions_per_prompt: Maximum number of implementations per prompt.
      cluster_sampling_temperature_init: Softmax temperature at the start of
        each temperature period.
      cluster_sampling_temperature_period: Number of registered programs after
        which the temperature schedule restarts.
    """
    self._template: code_manipulation.Program = template
    self._function_to_evolve: str = function_to_evolve
    self._functions_per_prompt: int = functions_per_prompt
    self._cluster_sampling_temperature_init = cluster_sampling_temperature_init
    self._cluster_sampling_temperature_period = (
        cluster_sampling_temperature_period)

    self._clusters: dict[Signature, Cluster] = {}
    self._num_programs: int = 0

  def register_program(
      self,
      program: code_manipulation.Function,
      scores_per_test: ScoresPerTest,
  ) -> None:
    """Stores a program on this island, in its appropriate cluster.

    A new cluster is created when no program with the same signature has been
    registered before.
    """
    signature = _get_signature(scores_per_test)
    if signature not in self._clusters:
      score = _reduce_score(scores_per_test)
      self._clusters[signature] = Cluster(score, program)
    else:
      self._clusters[signature].register_program(program)
    self._num_programs += 1

  def get_prompt(self, library: AbstractionLibrary) -> tuple[str, int]:
    """Constructs a prompt containing functions from this island.

    Args:
      library: Abstraction library whose descriptions are embedded in the
        prompt.

    Returns:
      A `(prompt_code, version_generated)` tuple. The prompt ends with the
      header of `{function_to_evolve}_v{version_generated}`.

    Raises:
      ValueError: If no program has been registered on this island yet.
    """
    if not self._clusters:
      # Fail with an explicit message instead of an opaque argmax error.
      raise ValueError(
          'Cannot build a prompt from an island with no registered programs.')

    signatures = list(self._clusters.keys())
    cluster_scores = np.array(
        [self._clusters[signature].score for signature in signatures])

    # Convert scores to probabilities using softmax with temperature schedule.
    # The temperature decays linearly within a period and then restarts.
    period = self._cluster_sampling_temperature_period
    temperature = self._cluster_sampling_temperature_init * (
        1 - (self._num_programs % period) / period)
    probabilities = _softmax(cluster_scores, temperature)

    # At the beginning of an experiment when we have few clusters, place fewer
    # programs into the prompt.
    functions_per_prompt = min(len(self._clusters), self._functions_per_prompt)

    idx = np.random.choice(
        len(signatures), size=functions_per_prompt, p=probabilities)
    implementations = []
    scores = []
    for i in idx:
      cluster = self._clusters[signatures[i]]
      implementations.append(cluster.sample_program())
      scores.append(cluster.score)

    # Order implementations from worst to best so versions read as a
    # progression of improvements.
    indices = np.argsort(scores)
    sorted_implementations = [implementations[i] for i in indices]
    # Implementations are named `_v0` .. `_v{n-1}` and the header to complete
    # is `_v{n}`, so the generated version is exactly `n`.
    version_generated = len(sorted_implementations)
    prompt_code = self._generate_prompt(sorted_implementations, library)
    return prompt_code, version_generated

  def _generate_prompt(
      self,
      implementations: Sequence[code_manipulation.Function],
      library: AbstractionLibrary
  ) -> str:
    """Creates a prompt containing a sequence of function `implementations`.

    The prompt consists of the template preface (if any), the library's
    abstraction descriptions, the implementations renamed to `_v0`, `_v1`, ...
    and finally the body-less header of the next version to be written.

    Args:
      implementations: Non-empty sequence of functions, ordered from worst to
        best. The inputs are deep-copied and not modified.
      library: Source of the abstraction descriptions.

    Returns:
      The prompt text.
    """
    implementations = copy.deepcopy(implementations)  # We will mutate these.

    # Format the names and docstrings of functions to be included in the prompt.
    versioned_functions: list[code_manipulation.Function] = []
    for i, implementation in enumerate(implementations):
      new_function_name = f'{self._function_to_evolve}_v{i}'
      implementation.name = new_function_name
      # Update the docstring for all subsequent functions after `_v0`.
      if i >= 1:
        implementation.docstring = (
            f'Improved version of `{self._function_to_evolve}_v{i - 1}`.')
      # If the function is recursive, replace calls to itself with its new name.
      renamed_source = code_manipulation.rename_function_calls(
          str(implementation), self._function_to_evolve, new_function_name)
      versioned_functions.append(
          code_manipulation.text_to_function(renamed_source))

    # Create the header of the function to be generated by the LLM.
    next_version = len(implementations)
    header = dataclasses.replace(
        implementations[-1],
        name=f'{self._function_to_evolve}_v{next_version}',
        body='',
        docstring=('Improved version of '
                   f'`{self._function_to_evolve}_v{next_version - 1}`.'),
    )

    abstraction_descriptions = library.format_for_sampler_prompt()
    logger.debug(
        "Adding abstraction descriptions to prompt:\n%s",
        abstraction_descriptions)

    preface = self._template.preface
    logger.debug(
        "Inside Island._generate_prompt, template preface is: %s", preface)

    # Sections are separated by real newline characters so the prompt is
    # valid, readable source text.
    sections = [
        preface + '\n\n' if preface else '',
        '# Available Abstraction Functions (if any):\n',
        abstraction_descriptions + '\n\n',
        '# Example Implementations:\n',
        '\n'.join(str(function) for function in versioned_functions),
        '\n# Function to Implement:\n',
        str(header),  # Just the header definition.
    ]
    prompt_string = ''.join(sections)

    # Log the beginning of the final prompt string for verification.
    logger.debug("Generated prompt string start:\n%s...", prompt_string[:500])
    return prompt_string

  def get_top_programs(
      self, num_to_sample: int
  ) -> List[Tuple[float, code_manipulation.Function]]:
    """Returns up to `num_to_sample` programs from the best-scoring clusters.

    Clusters are visited in descending score order and one program is sampled
    from each. A cluster whose sampling fails is logged and skipped.

    Returns:
      A list of `(cluster_score, program)` tuples, best cluster first. Empty if
      the island has no clusters.
    """
    if not self._clusters:
      return []

    ranked_clusters = sorted(
        self._clusters.values(),
        key=lambda cluster: cluster.score,
        reverse=True)

    top_programs = []
    for cluster in ranked_clusters:
      if len(top_programs) >= num_to_sample:
        break
      try:
        program = cluster.sample_program()
      except Exception as e:
        # Best-effort: a single failing cluster must not abort the collection.
        logger.error(
            "Failed to sample program from cluster with score %s: %s",
            cluster.score, e, exc_info=True)
        continue
      top_programs.append((cluster.score, program))

    return top_programs


class Cluster:
  """Programs from one island that all share the same Signature."""

  def __init__(self, score: float, implementation: code_manipulation.Function):
    """Starts the cluster with its first program.

    Args:
      score: Reduced score associated with the cluster's signature.
      implementation: First program of the cluster.
    """
    self._score = score
    self._programs: list[code_manipulation.Function] = [implementation]
    # Length of each program's string form, kept parallel to `_programs`.
    self._lengths: list[int] = [len(str(implementation))]

  @property
  def score(self) -> float:
    """Reduced score of the signature that this cluster represents."""
    return self._score

  def register_program(self, program: code_manipulation.Function) -> None:
    """Adds `program` (and the length of its string form) to the cluster."""
    self._programs.append(program)
    self._lengths.append(len(str(program)))

  def sample_program(self) -> code_manipulation.Function:
    """Samples a program, giving higher probability to shorter programs."""
    shortest = min(self._lengths)
    longest = max(self._lengths)
    # Offset from the shortest program, scaled by the longest length; the
    # small constant keeps the denominator non-zero.
    offsets = np.array(self._lengths) - shortest
    normalized_lengths = offsets / (longest + 1e-6)
    probabilities = _softmax(-normalized_lengths, temperature=1.0)
    return np.random.choice(self._programs, p=probabilities)
