# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Configuration of a FunSearch experiment. Only data classes, no methods."""
import dataclasses
from typing import List, Optional, Dict, Tuple
import os


@dataclasses.dataclass(frozen=True)
class RabbitMQConfig:
    """Configuration for RabbitMQ connection.

    Attributes:
      host: The hostname of the RabbitMQ server.
      port: The port of the RabbitMQ server.
      username: Username for authentication with the RabbitMQ server.
      password: Password for authentication with the RabbitMQ server.
      vhost: Virtual host name (presumably the RabbitMQ vhost the connection
        is opened on; confirm against the code that builds the connection).
    """
    host: str = 'rabbitmq'
    port: int = 5672
    username: str = 'guest'
    password: str = 'guest'
    # The annotation is required: without it `vhost` is only a class attribute,
    # not a dataclass field, and cannot be overridden via the constructor.
    vhost: str = "temp_1"

@dataclasses.dataclass(frozen=True)
class ProgramsDatabaseConfig:
    """Configuration of a ProgramsDatabase.

    Attributes:
        functions_per_prompt: Number of previous programs to include in current prompt.
        num_islands: Number of islands to maintain for diversity.
        reset_period: The interval (in seconds) at which the weakest islands are reset.
                      If None, resets occur only based on the number of stored programs.
        reset_programs: The number of stored programs after which the weakest islands are reset.
        cluster_sampling_temperature_init: Initial temperature for softmax sampling of clusters within an island.
        cluster_sampling_temperature_period: Period of linear decay of the cluster sampling temperature.
        prompts_per_batch: Batch size for processing prompts received from the database_queue.
        no_deduplication: Disable deduplication (default: False, set True to disable).
        prompt_limit: Maximum number of prompts that can be published. Once reached, no new prompts are constructed, but queued messages are still processed.
        optimal_solution_programs: Sets the number of additional programs to generate after the first optimal solution is found.
    """
    functions_per_prompt: int = 2
    num_islands: int = 10
    reset_period: Optional[int] = None
    reset_programs: int = 1200
    cluster_sampling_temperature_init: float = 0.1
    cluster_sampling_temperature_period: int = 30_000
    prompts_per_batch: int = 10
    # The three settings below need annotations to be real dataclass fields;
    # without them they are class attributes that the constructor cannot set.
    no_deduplication: bool = False
    prompt_limit: int = 400_000
    optimal_solution_programs: int = 20_000


@dataclasses.dataclass(frozen=True)
class SamplerConfig:
    """Configuration of the sampler.

    Attributes:
        samples_per_prompt: Number of independently sampled program continuations to obtain for each prompt.
        temperature: Controls randomness; higher values increase diversity, lower values make outputs more deterministic.
        max_new_tokens: The maximum number of tokens the LLM can generate in response.
        top_p: Determines the range of likely tokens the model samples from, keeping only the most probable ones.
        repetition_penalty: Penalizes repetitive text; values >1 discourage repetition, while 1 disables it.
        prompts_per_batch: Batch size for processing prompts received from the sampler_queue.
        temperature_period: Controls how fast the LLM's temperature decreases as more programs are registered.
                            If None, dynamic temperature adjustment is disabled.
    """
    samples_per_prompt: int = 2
    temperature: float = 0.9444444444444444
    max_new_tokens: int = 246
    top_p: float = 0.7777777777777778
    repetition_penalty: float = 1.222222
    # These two were previously unannotated and therefore plain class
    # attributes rather than dataclass fields. They are declared last so the
    # positional order of the pre-existing fields does not change.
    prompts_per_batch: int = 10
    temperature_period: Optional[int] = None


@dataclasses.dataclass(frozen=True)
class EvaluatorConfig:
    """Immutable settings for the evaluator.

    Attributes:
        timeout: Timeout in seconds for the sandbox.
    """
    timeout: int = dataclasses.field(default=30)

@dataclasses.dataclass(frozen=True)
class PromptConfig:
    """Configuration for constructing prompts.

    `spec_path`, `template_eval`, `reasoning_template_path`,
    `challenge_template_path` and `target_solutions` are not constructor
    arguments (`init=False`); they are derived in `__post_init__` from the
    other fields.

    Attributes:
        spec_path: Path to the specification file used in the experiment.
                This file serves two purposes:
                1. Constructing the evaluation script that runs and evaluates priority functions.
                2. Defining the prompt structure when `reasoning=False` and `challenge_vtcodes=False`.

                Default case: Uses `"load_graph/prompt_3.txt"` inside the LLM-specific
                specification directory (`gpt` if `gpt=True`, otherwise `StarCoder2`).

                If `challenge_vtcodes=True`, it changes to `"challenge_vtcodes/prompt1.txt"`, which also includes sequence overlap evaluation with VT codes.

                The category and file name are hard-coded in `__post_init__`; selecting another
                category (e.g. `construct_graph`) or template (e.g. `baseline.txt`, `prompt_1.txt`)
                requires editing that method.

        eval_code: If True, includes an evaluation script in the prompt (Prompt 2)

        include_nx: If True, includes the `networkx` package in the prompt.
                    Set False for `prompt_3.txt` in the `load_graph` or `construct_graph` prompt strategy.

        gpt: If True, uses GPT-4o mini. If False, defaults to StarCoder2.
             Forced to True in `__post_init__` whenever `reasoning` or `challenge_vtcodes` is True.

        reasoning: Enables reasoning-based prompts using `reasoning_template_path`.
                   This prompt strategy guides the LLM to analyze previous priority functions,
                   evaluate their scores per test before generating an improved priority function.

        challenge_vtcodes: Enables a prompt strategy that challenges VT codes using `challenge_template_path`.
                           If set, evaluation also computes sequence overlap with VT codes.

        reasoning_template_path: Path to the template used for reasoning-based prompts.
                                 Set whenever `gpt` is True; None otherwise.

        challenge_template_path: Path to the template used for VT-challenging prompts.
                                 Set whenever `gpt` is True; None otherwise.

        template_eval: Path to `load_graph/without_hash.txt` in the LLM-specific specification
                       directory (presumably the evaluation-script template; confirm against its consumer).

        vt_solution: Dictionary of VT codebook sizes for reference (used if `challenge_vtcodes=True`).

        target_solutions: Dictionary mapping `(n, s)` pairs to target sizes for improved deletion-correcting codes.
                         Automatically generated based on `s_values`, `start_n`, and `end_n`;
                         pairs without a known target are omitted.

        mode: Strategy for score reduction. Options: `'last'`, `'average'`, `'weighted'`, `'relative_difference'`.

        s_values: List of number of deletions `s` to consider in the experiment.
        start_n: List of shortest sequence lengths for each `s` value.
        end_n: List of longest sequence lengths (inclusive) for each `s` value.

    Raises:
        ValueError: If `s_values`, `start_n` and `end_n` do not have the same length.
    """

    spec_path: str = dataclasses.field(init=False)
    eval_code: bool = False
    include_nx: bool = True

    gpt: bool = False
    reasoning: bool = False
    challenge_vtcodes: bool = False

    reasoning_template_path: Optional[str] = dataclasses.field(init=False, default=None)
    challenge_template_path: Optional[str] = dataclasses.field(init=False, default=None)
    template_eval: str = dataclasses.field(init=False)

    # NOTE(review): the last two keys use s=2, yet their values (586, 1096)
    # continue the s=1 sequence (compare the s=1 targets 587 and 1097 in
    # `__post_init__`). Confirm whether (13, 1) and (14, 1) were intended.
    vt_solution: Dict[Tuple[int, int], float] = dataclasses.field(default_factory=lambda: {
        (9, 1): 52, (10, 1): 94, (11, 1): 172, (12, 1): 316, (13, 2): 586, (14, 2): 1096
    })

    mode: str = "last"
    s_values: List[int] = dataclasses.field(default_factory=lambda: [1,2])
    start_n: List[int] = dataclasses.field(default_factory=lambda: [9,10])
    end_n: List[int] = dataclasses.field(default_factory=lambda: [11,12])

    target_solutions: Dict[Tuple[int, int], float] = dataclasses.field(init=False)

    def __post_init__(self):
        """Derive the path fields and `target_solutions` from the user-set fields.

        Raises:
            ValueError: If `s_values`, `start_n` and `end_n` differ in length.
        """
        # Every s value needs exactly one (start_n, end_n) pair. Without this
        # check, zip() below would silently drop the unmatched trailing entries.
        if not len(self.s_values) == len(self.start_n) == len(self.end_n):
            raise ValueError(
                "s_values, start_n and end_n must have the same length, got "
                f"{len(self.s_values)}, {len(self.start_n)} and {len(self.end_n)}."
            )

        # The dataclass is frozen, so all assignments go through object.__setattr__.
        if self.reasoning or self.challenge_vtcodes:
            object.__setattr__(self, "gpt", True)

        llm_category = "gpt" if self.gpt else "StarCoder2"

        # Default prompt category
        prompt_category = "load_graph"  # Always use load_graph unless challenge_vtcodes=True
        spec_filename = "prompt_3.txt"

        if self.challenge_vtcodes:
            prompt_category = "challenge_vtcodes"
            spec_filename = "prompt1.txt"

        object.__setattr__(self, "spec_path", get_spec_path(llm_category, prompt_category, spec_filename))
        object.__setattr__(self, "template_eval", get_spec_path(llm_category, "load_graph", "without_hash.txt"))

        # Assign reasoning and challenge template paths, but do not modify spec_path
        if self.gpt:
            object.__setattr__(self, "reasoning_template_path", get_spec_path("gpt", "reasoning", "prompt1.txt"))
            object.__setattr__(self, "challenge_template_path", get_spec_path("gpt", "challenge_vtcodes", "surpass_VT.txt"))

        # Target code sizes, indexed as known_targets[s][n].
        known_targets = {
            1: {6: 10, 7: 16, 8: 30, 9: 52, 10: 94, 11: 172, 12: 317, 13: 587, 14: 1097, 15: 2049, 16: 3857},
            2: {7: 5, 8: 7, 9: 11, 10: 16, 11: 24, 12: 33, 13: 50, 14: 79, 15: 127, 16: 202}
        }

        # Keep only the (n, s) pairs inside the requested ranges for which a
        # target is known; everything else is skipped.
        dynamic_target_solutions = {}
        for s, start, end in zip(self.s_values, self.start_n, self.end_n):
            targets_for_s = known_targets.get(s, {})
            for n in range(start, end + 1):
                if n in targets_for_s:
                    dynamic_target_solutions[(n, s)] = targets_for_s[n]

        object.__setattr__(self, "target_solutions", dynamic_target_solutions)

def get_spec_path(llm_category: str, prompt_category: str, filename: str) -> str:
    """Build the path of a specification file.

    The starting directory is the absolute directory of this module, or the
    current working directory when `__file__` is not defined. If that
    directory contains the substring "FunDCC", the path is cut right after
    its first occurrence; otherwise the directory is used unchanged.

    Args:
        llm_category: Directory name of the LLM (e.g. "gpt" or "StarCoder2").
        prompt_category: Directory name of the prompt strategy.
        filename: Name of the specification file.

    Returns:
        `<root>/src/fundcc/specifications/<llm_category>/<prompt_category>/<filename>`.
        The file is not checked for existence.
    """
    if "__file__" in globals():
        anchor_dir = os.path.abspath(os.path.dirname(__file__))
    else:
        anchor_dir = os.getcwd()

    # partition() splits on the first "FunDCC"; an empty separator means the
    # marker is absent and the anchor directory itself serves as the root.
    prefix, marker, _ = anchor_dir.partition("FunDCC")
    project_root = prefix + marker if marker else anchor_dir

    spec_parts = ("src", "fundcc", "specifications", llm_category, prompt_category, filename)
    return os.path.join(project_root, *spec_parts)

@dataclasses.dataclass
class Config:
    """Top-level configuration of a FunSearch experiment.

    Each sub-configuration defaults to a freshly constructed instance of its
    config class.

    Attributes:
        rabbitmq: Configuration for RabbitMQ connection.
        programs_database: Configuration of the database.
        sampler: Configuration of the samplers.
        evaluator: Configuration of the evaluators.
        prompt: Configuration for the prompt.
        num_samplers: Number of independent Samplers in the experiment.
        num_evaluators: Number of independent program Evaluators in the experiment.
        num_pdb: Number of independent program databases. Currently supports only one, but this does not create a bottleneck.
    """
    rabbitmq: RabbitMQConfig = dataclasses.field(default_factory=RabbitMQConfig)
    programs_database: ProgramsDatabaseConfig = dataclasses.field(
        default_factory=ProgramsDatabaseConfig
    )
    sampler: SamplerConfig = dataclasses.field(default_factory=SamplerConfig)
    evaluator: EvaluatorConfig = dataclasses.field(default_factory=EvaluatorConfig)
    prompt: PromptConfig = dataclasses.field(default_factory=PromptConfig)
    num_samplers: int = 1
    num_evaluators: int = 1
    num_pdb: int = 1




