import wandb
import random
import numpy as np
import openai
from llm_utils import chat_completion_rl
import re
from torch.multiprocessing import Queue
import json
from copy import deepcopy
from llm_utils import get_llm_config
from llm_utils import setup_chat_rate_limiter, chat_completion_rl, get_llm_config, num_tokens_consumed_by_chat_request, get_model_max_tokens, embedding_rl, pretty_print_chat_messages
from utils.llm_tools import process_function_call_and_return_message, function_definition_list_factory, available_functions_factory, hash_messages, process_functions_into_function_names, clean_string, detect_cycles
from summarizer_and_modifier import add_line_numbers, write_files_from_dict, load_code_files_into_dict
from openai.error import InvalidRequestError
from pathlib import Path
from utils import prompts
# from log_evaluator import count_errors_from_file_dict, run_tests_from_file_dict
import pandas as pd
import os
import traceback
import re
from utils.models import DyNODEModel, RNNModel, TransformerModel
import torch
import time
from torch import optim
import pysindy as ps
from tqdm import tqdm
from pysindy import SINDy
from pysindy.optimizers import STLSQ, SR3, SSR, FROLS
from pysindy.differentiation import FiniteDifference, SmoothedFiniteDifference, SpectralDerivative
from pysindy.feature_library import PolynomialLibrary
from utils import gp_method

# Let float32 matrix multiplications use faster reduced-precision internals
# (e.g. TF32 where the hardware supports it), trading a little accuracy for speed.
torch.set_float32_matmul_precision(precision='high')


def get_code_from_message(message):
    """Extract the first fenced code snippet from an LLM reply.

    Fences are tried in priority order: a ```python fence, then a ``` python
    fence (space before the tag), then any ``` ... ``` fence. The generic
    fallback keeps whatever follows the opening backticks verbatim, including
    any language tag.

    Returns:
        The captured snippet, or None when no fence matches.
    """
    fence_patterns = (
        r'```python\n(.*?)\n```',
        r'``` python\n(.*?)\n```',
        r'```(.*?)```',
    )
    for pattern in fence_patterns:
        found = re.search(pattern, message, re.DOTALL)
        if found:
            return found.group(1)
    return None


def extract_failed_tests(test_output):
    """Return the failed-test count reported in ``test_output``.

    Reads the first "<N> failed" occurrence (the form pytest prints in its
    summary line, for example). If no such text is present, 10 is returned as
    a fallback value.
    """
    found = re.search(r'(\d+) failed', test_output)
    return int(found.group(1)) if found else 10

def initialize_agent(method_name, env, config, rate_limiter, wandb_project_name=None, logger=None):
    """Create the agent selected by ``method_name``.

    Any name containing "NSDT" builds an ``NSDT`` agent. "ZeroShot" and
    "ZeroOptim" build a ``ZeroShot`` agent, without or with parameter
    optimisation respectively.

    A Weights & Biases run is started first when ``wandb_project_name`` is
    truthy.

    Raises:
        ValueError: if ``method_name`` matches none of the known methods.
    """
    if wandb_project_name:
        wandb.init(project=wandb_project_name, config=config)

    if "NSDT" in method_name:
        return NSDT(env, config, logger, rate_limiter, method_name=method_name)
    if method_name in ("ZeroShot", "ZeroOptim"):
        return ZeroShot(
            env, config, logger, rate_limiter,
            optimize_params=(method_name == "ZeroOptim"),
            name=method_name,
        )
    raise ValueError(f"Unknown method name: {method_name}")

class Agent:
    """Base class for agents that act on an environment with LLM assistance.

    Subclasses must override :meth:`run` and must set ``self.name``, which
    :meth:`get_llm_config` reads but this base initializer never assigns.
    """

    def __init__(self, env, config, logger, rate_limiter):
        # Collaborators shared by every agent; seed_value stays None until seed().
        self.env, self.config = env, config
        self.logger, self.rate_limiter = logger, rate_limiter
        self.seed_value = None

    def run(self, state):
        """Advance the agent from ``state``; concrete agents must override this."""
        raise NotImplementedError("Agents must implement a run method.")

    def seed(self, seed_value):
        """Remember ``seed_value`` and seed Python's and NumPy's global RNGs with it."""
        self.seed_value = seed_value
        for seed_fn in (random.seed, np.random.seed):
            seed_fn(seed_value)

    def get_llm_config(self):
        """Return the LLM request config built by ``llm_utils.get_llm_config`` for this agent."""
        return get_llm_config(self.config, self.logger, self.name, self.rate_limiter)

class NSDT(Agent):
    """LLM agent that iteratively generates, executes and refines simulator code.

    Each generation asks the LLM (through function calling) for a
    ``StateDifferential`` implementation and runs it via
    ``process_function_call_and_return_message``. For the reflecting variants
    (``NSDT`` / ``NSDT-no-memory``), the validation losses and scenario outputs
    are fed back in a reflection prompt before new code is requested.
    """

    # Default directory for per-iteration model code and validation feedback;
    # override with ``config.run.model_validation_output_dir``.
    DEFAULT_SAVE_DIR = "/project/bii_nssac/people/hht9zt/Code_generated/PNAS/Model_validation_outputs_quantativeAnalysis"
    # Method names for which a reflection prompt is generated between generations.
    REFLECTING_METHOD_NAMES = ('NSDT', 'NSDT-no-memory')

    def __init__(self, env, config, logger, rate_limiter, method_name=''):
        super().__init__(env, config, logger, rate_limiter)
        self.name = method_name
        # Path of a JSON state file (as written by save_agent_state) to resume from; '' starts fresh.
        self.load_from_checkpoint = ''
        # Path of a JSON list of recorded LLM responses (e.g. a previous run's
        # '<name>_llm_responses.json') to replay instead of calling the API; '' disables replay.
        self.replay_llm_responses_path = ''
        self.replay_llm_responses_path_index = 0
        # Extra sampling temperature applied while the dialog is cycling (see get_llm_response).
        self.message_hash_same_increase_temperature = 0
        self.step_idx = 0
        self.max_re_tries = 30
        self.re_tries = 0
        self.max_iterations = 10
        self.executor_retry_attempts = getattr(self.config.run, "executor_retry_attempts", 5)
        self.save_dir = getattr(self.config.run, "model_validation_output_dir", None) or self.DEFAULT_SAVE_DIR
        os.makedirs(self.save_dir, exist_ok=True)

        if self.load_from_checkpoint:
            with open(self.load_from_checkpoint, 'r') as f:
                data = json.load(f)
            self.simulator_code_dict = data['simulator_code_dict']
            self.steps = data['steps']
            self.step = data['step']
            self.meta_messages = data['meta_messages']
            # save_agent_state does not write 'responses', so tolerate its absence.
            self.responses = data.get('responses', [])
            self.message_hash = data['message_hash']
            # save_agent_state reads self.messages, so it must exist after resuming too.
            self.messages = data.get('messages', [])
        else:
            self.simulator_code_dict = {}
            self.steps = []
            self.step = None
            self.meta_messages = []
            self.messages = []
            self.responses = []
            self.message_hash = hash_messages([])

        self.folder_path = f"{self.config.run.log_path.split('.txt')[0]}/{self.env.env_name}/{self.env.seed}/"
        Path(self.folder_path).mkdir(parents=True, exist_ok=True)
        self.max_tokens = get_model_max_tokens(config)
        self.functions = function_definition_list_factory()
        self.system_message = {"role": "system", "content": prompts.system_prompt()}

    def print_dialog(self, messages, response_msg=False):
        """Pretty-print ``messages`` with token usage and retry progress via ``pretty_print_chat_messages``."""
        num_tokens = num_tokens_consumed_by_chat_request(messages=messages, functions=self.functions)
        pretty_print_chat_messages(
            messages, num_tokens, self.max_tokens,
            logger=self.logger, response_msg=response_msg, step_idx=self.step_idx,
            total_steps=len(self.steps), max_re_tries=self.max_re_tries, re_tries=self.re_tries,
        )

    def save_agent_state(self, messages, beginning_step=''):
        """Write the agent's conversation state as JSON under ``self.folder_path``.

        Args:
            messages: Message list of the request being made; stored under
                ``'request_messages'`` (``'messages'`` holds ``self.messages``).
            beginning_step: If truthy, write a per-step snapshot named after
                ``self.step_idx`` instead of overwriting ``current_<name>_state.json``.
        """
        data_to_save = {
            'messages': self.messages,
            'request_messages': messages,
            'simulator_code_dict': self.simulator_code_dict,
            'steps': self.steps,
            'step': self.step,
            'meta_messages': self.meta_messages,
            'message_hash': self.message_hash,
        }
        if not beginning_step:
            path = f'{self.folder_path}current_{self.name}_state.json'
        else:
            path = f'{self.folder_path}NeurosymbolicLLMAgent_state_beginning_step_{self.step_idx}.json'
        with open(path, 'w') as f:
            json.dump(data_to_save, f)

    def get_llm_response(self, messages, max_tokens=None, n=1, print_=True):
        """Send ``messages`` to the chat model, or replay a recorded response.

        Args:
            messages: Chat history. If the last message carries a truthy
                ``'function_call'`` directive, it is moved into the request options
                and removed from the message. If the request raises, the directive
                is put back, so a retry forces the same call.
            max_tokens: Optional completion-token cap for the request.
            n: Number of completions to request. When not None, a list of
                messages is returned.
            print_: Pretty-print the request and the response.

        Returns:
            A list of assistant messages when ``n`` is not None or several choices
            came back; otherwise a single assistant message. Empty content is
            normalised to None.

        Raises:
            InvalidRequestError: on API rejection, on a recorded error in replay
                mode, or when the prompt exceeds ``self.max_tokens``.
        """
        if print_:
            self.print_dialog(messages)
        self.save_agent_state(messages)
        llm_config = self.get_llm_config()
        if max_tokens is not None:
            llm_config['max_tokens'] = max_tokens
        llm_config['messages'] = messages
        if n is not None:
            llm_config['n'] = n
        self._adjust_temperature_for_cycles(messages)
        llm_config['temperature'] = self.message_hash_same_increase_temperature
        llm_config['functions'] = self.functions
        forced_function_call = messages[-1].get('function_call')
        if forced_function_call:
            llm_config['function_call'] = forced_function_call
            # The directive is a request option, not part of the message sent to the API.
            del messages[-1]['function_call']
        try:
            response = self._fetch_llm_response(messages, llm_config)
        except Exception:
            if forced_function_call:
                # Restore it so get_llm_response_with_retries re-sends the same directive.
                messages[-1]['function_call'] = forced_function_call
            raise

        if len(response['choices']) > 1 or n is not None:
            message_responses = [self._choice_message(choice) for choice in response['choices']]
            if print_:
                self.print_dialog(message_responses, response_msg=True)
            return message_responses
        message_response = self._choice_message(response['choices'][0])
        if print_:
            self.print_dialog([message_response], response_msg=True)
        return message_response

    def _adjust_temperature_for_cycles(self, messages):
        """Raise the sampling temperature by 0.4 (capped at 1) when ``detect_cycles`` flags
        the dialog as repeating; otherwise anneal it by 0.1 towards 0."""
        tmp_messages = [clean_string(str(msg)) for msg in messages]
        if detect_cycles(tmp_messages):
            self.message_hash_same_increase_temperature += 0.4
            if self.message_hash_same_increase_temperature >= 1:
                self.message_hash_same_increase_temperature = 1
            self.logger.info(f'[Increasing LLM temperature to {self.message_hash_same_increase_temperature}]')
        elif self.message_hash_same_increase_temperature > 0:
            self.message_hash_same_increase_temperature -= 0.1
            if self.message_hash_same_increase_temperature <= 0:
                self.message_hash_same_increase_temperature = 0
            # Logged after the decrement so the message reports the new temperature.
            self.logger.info(f'[Annealing LLM temperature to {self.message_hash_same_increase_temperature}]')

    def _fetch_llm_response(self, messages, llm_config):
        """Return a raw completion: the next recorded one in replay mode, otherwise a live API call.

        Live responses (and request errors) are appended to ``self.responses``, which
        is re-written to ``<folder_path><name>_llm_responses.json`` after each success.
        """
        if self.replay_llm_responses_path:
            with open(self.replay_llm_responses_path, 'r') as f:
                responses = json.load(f)
            response = responses[self.replay_llm_responses_path_index]
            self.replay_llm_responses_path_index += 1
            if 'error' in response:
                raise InvalidRequestError(response['error'], '')
            return response
        try:
            # Refuse locally before spending an API call on an over-long prompt.
            num_tokens = num_tokens_consumed_by_chat_request(messages=messages, functions=self.functions)
            if num_tokens > self.max_tokens:
                raise InvalidRequestError('InvalidRequestError', 'SelfGeneratedErrorOverTokenLimit')
            response = chat_completion_rl(**llm_config)
            self.responses.append(response)
            with open(f'{self.folder_path}{self.name}_llm_responses.json', 'w') as f:
                json.dump(self.responses, f)
        except InvalidRequestError as e:
            self.responses.append({'error': 'InvalidRequestError'})
            self.logger.error('Error: InvalidRequestError')
            self.logger.error(traceback.format_exc())
            # Lazy %-formatting; the original passed an extra arg with no placeholder.
            self.logger.info("Error: %s", e.__dict__)
            raise
        return response

    @staticmethod
    def _choice_message(choice):
        """Return ``choice['message']`` with a missing or empty content normalised to None."""
        message = choice["message"]
        if not message.get('content'):
            message['content'] = None
        return message

    def get_function_names_as_str(self):
        """Return the callable function names as a comma-separated list of `backticked` names."""
        fns = process_functions_into_function_names(self.functions)
        return ', '.join(f'`{fn}`' for fn in fns)

    def run(self, state=''):
        """Entry point; delegates to ``_run``."""
        return self._run(state)

    def get_llm_response_with_retries(self, messages, n=None, print_=True):
        """Call ``get_llm_response`` until it succeeds, trimming history on InvalidRequestError.

        After each InvalidRequestError, messages from index 3 onward are dropped
        until at least 700 tokens of headroom remain. The first three messages and
        the final (current) message are always kept. Retries count against
        ``self.max_re_tries``; the counter is shared across calls and reset by ``_run``.

        Raises:
            ValueError: once ``self.re_tries`` exceeds ``self.max_re_tries``.
        """
        while True:
            try:
                return self.get_llm_response(messages, n=n, print_=print_)
            except InvalidRequestError:
                # len > 4 protects the preamble and the current prompt, and avoids
                # IndexError from pop(3) on short histories.
                while len(messages) > 4 and (self.max_tokens - num_tokens_consumed_by_chat_request(messages=messages, functions=self.functions)) < 700:
                    messages.pop(3)
                self.re_tries += 1
                if self.re_tries > self.max_re_tries:
                    self.logger.warning(f'[WARNING] Maximum re-tries reached: {self.re_tries}/{self.max_re_tries}, exiting run!')
                    raise ValueError(f'[ERROR] Maximum re-tries reached: {self.re_tries}/{self.max_re_tries}, stopping run.')

    def _normalize_response_message(self, response_message):
        """Return the first message when given a list of LLM messages, else the message itself.

        Raises:
            ValueError: if given an empty list.
        """
        if isinstance(response_message, list):
            if len(response_message) == 0:
                raise ValueError("Received empty response message list from LLM.")
            return response_message[0]
        return response_message

    def execute_function_call_with_llm_retries(self, response_message, base_messages):
        """Execute the LLM's function call, feeding failures back to the LLM for a fix.

        Makes up to ``self.executor_retry_attempts`` execution attempts. After each
        failed attempt except the last, the error is sent back using
        ``prompts.execution_error_feedback_prompt``. It goes on a copy of
        ``base_messages`` extended with the failed responses, and the LLM's new
        response is executed next.

        Returns:
            Whatever ``process_function_call_and_return_message`` returns (``_run``
            unpacks it as ``(function_return_message, code_dict, has_success)``).

        Raises:
            RuntimeError: if no attempt succeeds (chained from the last error).
        """
        normalized_response = self._normalize_response_message(response_message)
        attempt_number = 1
        last_error = None
        messages_for_retry = deepcopy(base_messages)
        messages_for_retry.append(normalized_response)

        while attempt_number <= self.executor_retry_attempts:
            try:
                function_call = normalized_response.get("function_call")
                if not function_call:
                    # Give the LLM an actionable error instead of a bare KeyError.
                    raise ValueError("LLM response did not include a function call.")
                return process_function_call_and_return_message(
                    function_call,
                    self.simulator_code_dict,
                    env=self.env,
                    functions=self.functions,
                    config=self.config,
                    logger=self.logger
                )
            except Exception as exc:
                last_error = exc
                if self.logger:
                    self.logger.warning(f"ERROR WHILE RUNNING THE CODE ------> [{self.name} | Executor] Attempt {attempt_number}/{self.executor_retry_attempts} failed: {exc}")

                if attempt_number >= self.executor_retry_attempts:
                    break

                error_prompt = prompts.execution_error_feedback_prompt(
                    self.env.env_name,
                    str(exc),
                    attempt_number,
                    self.executor_retry_attempts
                )
                messages_for_retry.append({
                    "role": "user",
                    "content": error_prompt,
                    "function_call": {"name": 'complete_StateDifferential_code'}
                })
                retry_response = self.get_llm_response_with_retries(messages_for_retry, n=1)
                normalized_response = self._normalize_response_message(retry_response)
                messages_for_retry.append(normalized_response)
                attempt_number += 1

        error_hint = str(last_error) if last_error else "Unknown executor error."
        raise RuntimeError(
            f"Unable to produce executable code after {self.executor_retry_attempts} attempts. "
            f"Please modify the prompt to provide missing constraints. Last error: {error_hint}"
        ) from last_error

    @staticmethod
    def _infected_loss_as_float(val):
        """Reduce an 'infected' loss (tensor, number, or list of either) to a Python float.

        Lists are averaged; an empty list yields NaN instead of ZeroDivisionError.
        """
        if isinstance(val, list):
            if not val:
                return float('nan')
            return sum(v.item() if torch.is_tensor(v) else float(v) for v in val) / len(val)
        if torch.is_tensor(val):
            return val.item()
        return float(val)

    def _save_generation_checkpoint(self, generation_dict):
        """Save a generation's losses, scenario output and optimised parameters with ``torch.save``.

        Writes ``checkpoints/<env_name>/<method>/iter_<iteration:04d>.pt``, relative to the
        current working directory.
        """
        code_dict = generation_dict["code_dict"]
        # Generation dicts built by _run carry no 'method_name' key. Fall back to this
        # agent's name so NSDT variants do not share (and overwrite) one folder.
        method_name = generation_dict.get("method_name", self.name or "NSDT")
        ckpt_dir = os.path.join("checkpoints", self.env.env_name, method_name)
        os.makedirs(ckpt_dir, exist_ok=True)

        checkpoint = {
            "iteration": generation_dict["iteration"],
            "val_loss": code_dict["val_loss"],
            "loss_per_dim": code_dict["loss_per_dim_dict"],
            "sc_output": code_dict['sc_output'],
            "optimized_parameters": code_dict["optimized_parameters"],
            "env_name": self.env.env_name,
        }
        ckpt_path = os.path.join(ckpt_dir, f"iter_{generation_dict['iteration']:04d}.pt")
        torch.save(checkpoint, ckpt_path)
        print(f"[Checkpoint saved] {ckpt_path}")

    def generate_reflection_competition_for_generation_dict(self, generation_dict):
        """Format one evaluated program as a text block for the reflection prompt.

        Side effects:
            - For Covid-scenario / MRSA-meta / PNAS, the ``'infected'`` per-dimension
              loss is collapsed to a float in place, and the code_dict is printed.
            - A checkpoint for the generation is written (see
              ``_save_generation_checkpoint``).

        Raises:
            ValueError: for an environment name this method does not handle.
        """
        env_name = self.env.env_name
        code_dict = generation_dict['code_dict']
        loss_per_dim = code_dict['loss_per_dim_dict']
        if env_name in ('Cancer', 'Cancer-ood', 'Cancer-iid') or 'Cancer-random' in env_name:
            val_loss_per_dim_str = f"(Where the val loss per dimension is tumor_volume val loss: {loss_per_dim['tumor_volume_val_loss']:.3g}, chemotherapy_drug_concentration val loss: {loss_per_dim['chemo_drug_concentration_val_loss']:.3g})"
        elif env_name in ('Covid-scenario', 'MRSA-meta', 'PNAS'):
            val = self._infected_loss_as_float(loss_per_dim['infected'])
            loss_per_dim['infected'] = val
            val_loss_per_dim_str = f"(Where the val loss per dimension is infected val loss: {val:.3g})."
            print(code_dict)
        else:
            raise ValueError(f'Unknown env name: {self.env.env_name}')
        completion = f"""
Val Loss: {code_dict['val_loss']:.3g} {val_loss_per_dim_str} Iteration: {generation_dict['iteration']}
###
```
{code_dict['StateDifferential_code']}
"""
        self._save_generation_checkpoint(generation_dict)
        return completion

    def generate_reflection_prompt_with_group(self, generation_dicts, history_generation, iteration, history_best_generation):
        """Build the prompt asking the LLM to critique the current programs.

        Args:
            generation_dicts: Evaluated programs. They are shown sorted by
                descending fitness (validation loss), i.e. lowest loss last.
            history_generation: Unused; kept for interface compatibility.
            iteration: Index of the generation the feedback is for.
            history_best_generation: Best program of each previous generation.

        Raises:
            NotImplementedError: for method names other than NSDT / NSDT-no-memory.
            ValueError: if ``generation_dicts`` is empty.
        """
        generation_dicts = sorted(generation_dicts, key=lambda d: d['fitness'], reverse=True)
        history_best_completions_str = '\n'.join(
            f"Iteration {idx}. Best Val Loss: {past_completion['code_dict']['val_loss']}. Model description: {past_completion['code_dict']['model_description']}"
            for idx, past_completion in enumerate(history_best_generation)
        )
        # Each formatted completion also writes that generation's checkpoint.
        completions = '\n'.join(
            self.generate_reflection_competition_for_generation_dict(generation_dict)
            for generation_dict in generation_dicts
        )
        if self.name not in self.REFLECTING_METHOD_NAMES:
            raise NotImplementedError
        if not generation_dicts:
            raise ValueError("generation_dicts must contain at least one evaluated program.")
        # Sorted by descending loss, so the last entry is the best program; its scenario
        # outputs are the ones presented for validation.
        best_generation = generation_dicts[-1]
        return f"""
You generated the following code completions, which then had their parameters optimized to the training dataset. Please reflect on how you can improve the code to minimize the validation loss to 1e-6 or less. The code examples are delineated by ###.

Here are your previous iterations the best programs generated. Use it to see if you have exhausted white box models, i.e. when a white box model repeats with the same val loss and then only add black box models to the white box models:```
{history_best_completions_str}
```

Here are the top code completions so far that you have generated, sorted for the lowest validation loss last:```
{completions}
```

Please do the following:

1. Review the code, determine whether it's correct or not. If not, give steps to make it correct in the next iteration, and also check whether it has any negative disease parameters. DO NOT ADVISE TO FORCEFULLY MAKE ANY PARAMS NON-NEGATIVE USING RELU OR SOFTPLUS.
2. Do Not make any changes in the flow.
3. Please reflect on how you can improve the code to fit the dataset as accurately as possible, and be interpretable. Think step-by-step. Provide only actionable feedback. Do not write out the code. Where applicable, use the values of the optimized parameters to reason how the code can be improved to fit the dataset as accurately as possible. This is for generating new code for the next iteration {iteration} out of {self.config.run.generations}.

4. ***Modeling Assumptions and Verifying steps*** 
- Summarize key assumptions common to all scenarios.
- List concrete verification steps for the generated model

5. *** Constraints for modeling *****
- List explicit constraints implied by the input scenarios that must be enforced in future models.

6. ***Scenario Output Validation***
Using the scenario outputs below:
```{best_generation['code_dict']['sc_output']}```

7. Validate the model Output for all scenarios given below:
    a. Explain how the outputs validate each scenario, or why they fail to do so.
    b. Specify what should be fixed next if validation fails.
    c. Describe how scenario outputs should differ across each scenario (e.g., peaks, growth rates), and what is shown in the result.

Finally, make a summary table, containing the columns: Scenario, Should differ by, What to check in output, What is in the output, is consistent.

"""

    def _code_regeneration_prompt(self, generation_id):
        """Return the user prompt asking the LLM to regenerate code for ``generation_id``."""
        return f"""
Please now regenerate the code function, with the aim to improve the code to achieve a lower validation error. Use the feedback where applicable. You are generating code for iteration {generation_id} out of {self.config.run.generations} total iterations. When generating code if you are unsure about something, take your best guess. You have to generate code, and cannot give an empty string answer.

Please always only fill in the following code skeleton:```
{prompts.get_skeleton_code(self.env.env_name)}
```
You cannot change the code skeleton, or input variables.
"""

    def _evaluate_response(self, response_message, generation_id, computed_funcs, generation_dicts, history_generation):
        """Execute one code response, or reuse the cached result for identical code.

        New programs are appended to ``generation_dicts`` and ``history_generation``,
        and their function call is added to ``computed_funcs``. Duplicates are copied
        from ``history_generation``, re-tagged with ``generation_id``, and appended
        only to ``history_generation``.

        Returns:
            ``(generation_dict, has_success)``. ``has_success`` is always False for duplicates.
        """
        code_json = json.dumps(response_message["function_call"])
        if code_json not in computed_funcs:
            function_return_message, code_dict, has_success = self.execute_function_call_with_llm_retries(response_message, self.messages)
            generation_dict = {'function_return_message': deepcopy(function_return_message), 'iteration': generation_id, 'code_dict': deepcopy(code_dict), 'has_success': deepcopy(has_success), 'fitness': deepcopy(code_dict['val_loss']), 'code_string': code_json}
            generation_dicts.append(generation_dict)
            history_generation.append(generation_dict)
            computed_funcs.add(json.dumps(response_message["function_call"]))
            return generation_dict, has_success
        generation_dict = deepcopy([d for d in history_generation if d['code_string'] == code_json][0])
        generation_dict['iteration'] = generation_id
        history_generation.append(generation_dict)
        return generation_dict, False

    def _save_model_code(self, generation_id, generation_dict):
        """Write the program's StateDifferential code to ``<save_dir>/Model_Iteration<generation_id+1>.py``."""
        filepath = os.path.join(self.save_dir, f"Model_Iteration{generation_id+1}.py")
        with open(filepath, "w") as f:
            f.write(generation_dict['code_dict']['StateDifferential_code'])
        self.logger.info(f"[INFO] Saved simulator code to: {filepath}")

    def _save_validation_feedback(self, generation_id, response_message):
        """Best-effort write of the reflection text to ``<save_dir>/Validation_feedback_<generation_id+1>.txt``."""
        filepath = os.path.join(self.save_dir, f"Validation_feedback_{generation_id+1}.txt")
        try:
            with open(filepath, "w") as f:
                # get_llm_response normalises empty content to None; write '' instead.
                f.write(response_message.get("content") or "")
        except (OSError, TypeError, AttributeError):
            self.logger.info('Couldnt save the validation feedback')

    def _log_generation_summary(self, generation_id, generation_dicts, current_gen_val_loss):
        """Log fitness statistics for ``generation_dicts`` (sorted best-first) and return the top fitness."""
        fitnesses = [d['fitness'] for d in generation_dicts]
        mean_fitness = np.mean(fitnesses)
        num_programs = len(generation_dicts)
        top_fitness = generation_dicts[0]['fitness']
        tag = f"[{self.name} | {self.env.env_name} | {self.env.seed}]"
        self.logger.info(f"{tag}[Generation {generation_id}] | Top Fitness: {top_fitness} | Num Programs: {num_programs} | Mean Fitness: {mean_fitness} | Fitnesses: {fitnesses} | Current Gen Val Loss: {current_gen_val_loss}")
        result = {'method': self.name, 'env_name': self.env.env_name, 'seed': self.env.seed, 'generation': generation_id, 'top_fitness': top_fitness, 'num_programs': num_programs, 'mean_fitness': mean_fitness, 'fitnesses': fitnesses, 'current_gen_val_loss': current_gen_val_loss}
        self.logger.info(f"{tag}[GEN RESULT] {result}")
        return top_fitness

    def _run(self, state=''):
        """Run the generate/execute/reflect loop and return the best program's test loss.

        Generation 0 requests code from the task prompt alone. Each later generation,
        up to ``config.run.generations``, does the following:
            - reflects on the current top programs (reflecting variants only);
            - requests new code;
            - keeps the ``config.run.keep_top_samples`` best programs with a non-NaN loss;
            - stops early after ``config.run.nsdt_patience`` generations without improvement.

        Args:
            state: Unused.

        Returns:
            ``test_loss`` from the code_dict of the best remaining program.

        Raises:
            RuntimeError: if no candidate program remains after filtering.
        """
        self.messages = [self.system_message]
        initial_prompt = prompts.first_task_prompt(env_name=self.env.env_name, generations=self.config.run.generations)
        print(initial_prompt)
        print('')
        generation_id = 0
        n = 1
        self.logger.info(f'[Running generation 0] {self.name} | {self.env.env_name} | {self.env.seed} | Sampling n={n} keep_top_samples')

        self.messages.append({"role": "user", "content": initial_prompt, 'function_call': {"name": 'complete_StateDifferential_code'}})
        self.save_agent_state(self.messages)
        self.max_re_tries = 30
        self.re_tries = 0
        response_messages = self.get_llm_response_with_retries(self.messages, n=n)  # Code message
        generation_dicts = []
        history_generation = []
        history_best_generation = []
        computed_funcs = set()
        for response_message in response_messages:
            generation_dict, _ = self._evaluate_response(response_message, generation_id, computed_funcs, generation_dicts, history_generation)
        self._save_model_code(generation_id, generation_dict)

        generation_dicts = sorted(generation_dicts, key=lambda d: d['fitness'])
        history_best_generation.append(deepcopy(generation_dicts[0]))
        self._log_generation_summary(generation_id, generation_dicts, generation_dict['fitness'])

        best_fitness = float('inf')
        patience_counter = 0  # Generations in a row without improvement.

        for generation_id in range(1, self.config.run.generations):
            generation_dicts = sorted(generation_dicts, key=lambda d: d['fitness'])
            reflection_prompt = self.generate_reflection_prompt_with_group(generation_dicts=generation_dicts, history_generation=history_generation, iteration=generation_id, history_best_generation=history_best_generation)
            messages_i = deepcopy(self.messages)
            if self.name in self.REFLECTING_METHOD_NAMES:
                messages_i.append({"role": "user", "content": reflection_prompt, 'function_call': 'none'})
                response_message = self.get_llm_response_with_retries(messages_i, print_=False)  # Reflection message
                messages_i.append(response_message)
                self._save_validation_feedback(generation_id, response_message)
                messages_i.append({"role": "user", "content": self._code_regeneration_prompt(generation_id), 'function_call': {"name": 'complete_StateDifferential_code'}})
            response_messages = self.get_llm_response_with_retries(messages_i, n=n)  # Code message
            for response_message in response_messages:
                generation_dict, has_success = self._evaluate_response(response_message, generation_id, computed_funcs, generation_dicts, history_generation)
                # Save before a possible break, so successful programs are written too.
                self._save_model_code(generation_id, generation_dict)
                if has_success:
                    break

            generation_dicts = sorted(generation_dicts, key=lambda d: d['fitness'])
            generation_dicts = [d for d in generation_dicts if not np.isnan(d['fitness'])]
            generation_dicts = generation_dicts[:self.config.run.keep_top_samples]
            if not generation_dicts:
                raise RuntimeError(
                    f"[{self.name} | {self.env.env_name} | {self.env.seed}][Generation {generation_id}] "
                    "No candidate programs remain after dropping NaN validation losses and keeping the top samples."
                )
            history_best_generation.append(deepcopy(generation_dicts[0]))
            top_fitness = self._log_generation_summary(generation_id, generation_dicts, generation_dict['fitness'])

            if top_fitness < best_fitness:
                best_fitness = top_fitness
                patience_counter = 0
            else:
                patience_counter += 1
            if patience_counter >= self.config.run.nsdt_patience:
                self.logger.info(f"Early stopping triggered at generation {generation_id}")
                break

        best_program = generation_dicts[0]['code_dict']
        test_loss = best_program['test_loss']
        self.logger.info(f'[{self.env.env_name} {self.name} {self.env.seed}][Test Run completed successfully] MSE TEST LOSS {test_loss:.4f}')
        return test_loss


                                   
def sample_all_groups(lst, n):
    """
    Sample randomly without replacement groups of n from the list until all objects are sampled.
    The last group can be less than n.

    The input is not modified: a shuffled copy is partitioned. The copy is
    shuffled with ``random.shuffle``, so seeding ``random`` gives the same
    permutation as shuffling the list in place would.

    :param lst: The items to sample from (any iterable; it is copied).
    :param n: The size of each group; must be at least 1.
    :return: A list of sampled groups (lists).
    :raises ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"Group size n must be at least 1, got {n}.")
    shuffled = list(lst)
    random.shuffle(shuffled)
    return [shuffled[i:i + n] for i in range(0, len(shuffled), n)]