from .headers import *
from src.searchlight.bandit import MultiarmedBanditLearner
from .llm_utils.llm_api_models import LLMModel
# from .prompts.improvement_prompts import gen_specific_improvement_prompt, gen_draw_conclusions_from_feedback_prompt, gen_implement_function_from_improvement_prompt
from .prompts.prompt_generators import PromptGenerator
from src.searchlight.utils import UpdatablePriorityDictionary

import numpy as np
import os

from typing import Optional

class BeamEvolver(Evolver):
    '''
    Abstract class for evolving functions
    '''

    functions_dict: UpdatablePriorityDictionary # where values are (abstract, feedback, iteration)
    auxilary_functions_dict: list[tuple[str, dict, float]]

    def __init__(self, evaluator: Evaluator, analyzer: FeedbackAnalyzer, prompt_generator: PromptGenerator, batch_size: int = 10, seed_functions: Optional[list[tuple[str, dict]]] = None, check_function: Callable[[str], bool] = lambda x: True, parse_function: Callable[[str],str] = lambda x: x, model: Optional[LLMModel] = None, num_fittest_functions: int = 1, select_from_last_generation_only: bool = False, forward_fill: bool = False, add_old_functions_to_evaluation: bool = True, evaluate_old_functions: bool = True):
        '''
        Sets up the evolver state and registers the seed functions.

        Args:
            evaluator: evaluator whose `evaluate` scores lists of functions
            analyzer: feedback analyzer whose `translate` processes stored feedback before prompting
            prompt_generator: builds the prompts sent to the model
            batch_size: number of functions to propose at a time (also the number of seed generation attempts)
            seed_functions: list of seed functions to start with, (function, notes); generated with the model when None
            check_function: function to check if a function is valid
            parse_function: function applied to the raw model output to extract the function
            model: LLM model to use for generating functions; defaults to GPT35Multi()
            num_fittest_functions: number of fittest functions to consider each iteration
            select_from_last_generation_only: whether to select parents from the last generation only
            forward_fill: whether to re-propose the parent function when no new function could be generated
            add_old_functions_to_evaluation: in last-generation mode, whether re-evaluated old functions may stay in the auxilary functions list
            evaluate_old_functions: whether the fittest functions are re-evaluated together with the proposals

        Raises:
            ValueError: if a seed function is rejected by check_function
        '''
        super().__init__()
        if model is None:
            model = GPT35Multi()
        self.evaluator = evaluator
        self.batch_size = batch_size
        self.analyzer = analyzer
        self.functions_dict = UpdatablePriorityDictionary()
        self.auxilary_functions_list = []
        self.num_evolutions = 0 # number of evolutions conducted
        self.check_function = check_function
        self.model = model
        self.prompt_generator = prompt_generator
        self.parse_function = parse_function
        self.num_fittest_functions = num_fittest_functions
        self.select_from_last_generation_only = select_from_last_generation_only
        self.forward_fill = forward_fill
        self.add_old_functions_to_evaluation = add_old_functions_to_evaluation # NOTE: we always evaluate the old functions now, but we might not add them to the auxilary functions list
        self.evaluate_old_functions = evaluate_old_functions

        # create logger
        self.logger = logging.getLogger(self.__class__.__name__)

        # if seed functions are None, generate them (thought first, then function)
        if seed_functions is None:
            seed_functions = self.generate_seed_functions(10)

        # add seed functions
        self.add_seed_functions(seed_functions)

    def generate_and_save_seed_functions(self, save_directory: str) -> None:
        '''
        Generates seed functions using generate_seed_functions and saves them in a CSV file.

        The CSV has the columns 'function' and 'abstract'. The 'abstract' column holds the
        plain abstract text (not the notes dictionary), so that load_seed_functions, which
        wraps each cell in {'abstract': ...}, reproduces the generated format.

        Args:
            save_directory: directory to save the CSV file (created if it does not exist)
        '''
        # Generate seed functions
        seed_functions = self.generate_seed_functions()

        # Unwrap the notes dictionary so only the abstract text is written
        rows = []
        for function, notes in seed_functions:
            abstract = notes.get('abstract') if isinstance(notes, dict) else notes
            rows.append((function, abstract))
        seed_functions_df = pd.DataFrame(rows, columns=['function', 'abstract'])

        # Make sure the target directory exists before writing
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)

        # Define the filename for the CSV file
        seed_functions_filename = os.path.join(save_directory, 'seed_functions.csv')

        # Save the seed functions to the CSV file
        seed_functions_df.to_csv(seed_functions_filename, index=False)

        self.logger.info(f"Seed functions saved to {seed_functions_filename}")

    @staticmethod
    def load_seed_functions(save_directory: str) -> list[tuple[str, dict]]:
        '''
        Loads seed functions from the CSV file into the format generated by generate_seed_functions.

        Args:
            save_directory: Directory where the CSV file containing seed functions is saved.

        Returns:
            List of tuples, where the first element is the function string and the second element is a dictionary
            containing the abstract.
        '''
        # Construct the file path for the seed functions CSV file
        seed_functions_filename = os.path.join(save_directory, 'seed_functions.csv')

        # Read the CSV file
        df = pd.read_csv(seed_functions_filename)

        # Extract function strings and abstracts
        functions = df['function']
        abstracts = df['abstract']

        # Convert abstracts to dictionaries
        abstract_dicts = abstracts.apply(lambda x: {'abstract': x})

        # Combine function strings and abstract dictionaries into tuples
        seed_functions = list(zip(functions, abstract_dicts))

        return seed_functions
            
    def add_seed_functions(self, seed_functions: list[tuple[str, dict]]):
        '''
        Validate, evaluate and add seed functions to the functions dictionary.

        Every seed function is stored with notes containing the evaluation feedback,
        iteration 0, generation 0, no predecessor and an empty idea trace; keys in the
        seed's own notes take precedence over these defaults.

        Args:
            seed_functions: list of seed functions to add to the functions dictionary, of the form (function, notes)

        Raises:
            ValueError: if any seed function is rejected by check_function (nothing is evaluated or added in that case)
        '''
        # check all seed functions using check_function first. raise error if any of them are not valid
        for function, _ in seed_functions:
            if not self.check_function(function):
                raise ValueError(f'Seed function {function} is not valid')

        # evaluate all seed functions in a single call
        scores, eval_notes = self.evaluator.evaluate([function for function, _ in seed_functions])

        for i, (function, func_notes) in enumerate(seed_functions):
            notes = {'feedback': eval_notes[i], 'iteration': 0, 'generation': 0, 'predecessor_function': None, 'idea_trace': []} | func_notes
            self.functions_dict.add_or_update_key(function, notes, scores[i]) # TODO: check sign of score

            # if select_from_last_generation_only is True, the seeds also form the initial beam
            if self.select_from_last_generation_only:
                self.auxilary_functions_list.append((function, notes, scores[i]))

        # rank the beam once after all seeds are appended (instead of on every iteration)
        # and keep only the num_fittest_functions best entries
        if self.select_from_last_generation_only and seed_functions:
            ranked = sorted(self.auxilary_functions_list, key=lambda x: x[2], reverse=True)
            self.auxilary_functions_list = ranked[:self.num_fittest_functions]
                

    def generate_seed_functions(self, num_tries: int = 10) -> list[tuple[str, dict]]:
        '''
        Generates seed functions to start the evolution using thought then function

        For each of the batch_size attempts, a thought is sampled from the model and then
        turned into a function with generate_function. Attempts for which generate_function
        returns None are dropped, so fewer than batch_size seeds may be returned.

        Args:
            num_tries: number of tries passed to generate_function for each seed

        Returns:
            list of (function, {'abstract': thought}) tuples
        '''
        generated = []
        for _ in range(self.batch_size):
            # first ask the model for a thought, then implement that thought
            thought_prompt = self.prompt_generator.gen_seed_thought_prompt()
            thought = self.model.generate(thought_prompt, 1)[0]
            function_prompt = self.prompt_generator.gen_seed_function_prompt_with_thought(thought)
            candidate = self.generate_function(prompt = function_prompt, tries = num_tries)
            if candidate is None:
                continue
            generated.append((candidate, {'abstract': thought}))
        return generated

    def get_fittest(self, k: int = 1) -> list[tuple[str, dict, float]]:
        '''
        Returns the k fittest items (highest to lowest). If there are less than k functions, return all functions

        Items of the form (function, dict(abstract, feedback, iteration), priority)

        Args:
            k: maximum number of items to return. In last-generation mode a negative k
               returns the whole auxilary functions list (assumed to mirror the k=-1
               convention used with get_top_k_items in produce_analysis -- TODO confirm).
        '''
        if not self.select_from_last_generation_only:
            # select from every function seen so far
            return self.functions_dict.get_top_k_items(k)

        # select from the last generation only; the list is kept sorted from best to worst
        if k < 0:
            return self.auxilary_functions_list
        return self.auxilary_functions_list[:k]
    
    def evaluate(self, functions) -> tuple[list[float], list[dict]]:
        '''
        Evaluates the given functions by delegating to the evaluator.

        Args:
            functions: functions to evaluate

        Returns:
            whatever self.evaluator.evaluate returns; callers unpack it as (scores, notes)
        '''
        evaluation = self.evaluator.evaluate(functions)
        return evaluation
    
    
    def evolve_once(self):
        '''
        Conducts one cycle of evolution

        Steps:
            1. take the fittest functions as parents (returns early if there are none)
            2. for every parent, draw conclusions from its feedback and generate an improved
               function; passes over the parents repeat until the number of attempts reaches
               batch_size (the last pass is always completed, so it can overshoot)
            3. optionally append the parents themselves for re-evaluation
            4. evaluate everything, write the feedback into the notes and store the results
        '''
        self.num_evolutions += 1

        parents = self.get_fittest(self.num_fittest_functions)
        if len(parents) == 0:
            self.logger.info('No functions in the library')
            return

        proposed_functions, is_new_function, func_notes = [], [], []
        attempts = 0

        while attempts < self.batch_size:
            for parent_func, parent_info, _priority in parents:
                processed_feedback = self.analyzer.translate(parent_info['feedback'])
                generation = parent_info['generation']

                # first draw conclusions, then ask for an improved function based on them
                conclusions_prompt = self.prompt_generator.gen_draw_conclusions_from_feedback_prompt(parent_func, processed_feedback)
                conclusions = self.model.generate(conclusions_prompt, 1)[0]
                improvement_prompt = self.prompt_generator.gen_improved_function_prompt(conclusions_prompt + conclusions)
                child = self.generate_function(improvement_prompt)

                if child is not None:
                    is_new_function.append(True)
                    proposed_functions.append(child)
                    func_notes.append({'iteration': self.num_evolutions, 'generation': generation + 1, 'predecessor_function': parent_func})
                elif self.forward_fill:
                    # generation failed: re-propose the parent. The parent's own notes
                    # dictionary is reused (not copied).
                    is_new_function.append(False)
                    proposed_functions.append(parent_func)
                    func_notes.append(parent_info)

                attempts += 1

        # also add all fittest functions to proposed functions (again sharing their notes dictionaries)
        if self.evaluate_old_functions:
            for parent_func, parent_info, _priority in parents:
                proposed_functions.append(parent_func)
                func_notes.append(parent_info)
                is_new_function.append(False)

        # evaluate the proposed functions
        scores, notes = self.evaluate(proposed_functions)

        # attach the feedback to the notes; for shared parent dictionaries this
        # overwrites the parent's stored 'feedback' in place
        for idx, note in enumerate(func_notes):
            note['feedback'] = notes[idx]

        # now store the improved functions
        self.store_improved_functions(proposed_functions, is_new_function, scores, func_notes)

        

    def store_improved_functions(self, proposed_functions, is_new_function, scores, func_notes):
        '''
        Stores the evaluated proposals of one evolution cycle.

        New functions whose score is not -inf are added to the functions dictionary.
        If select_from_last_generation_only is set, the auxilary functions list is rebuilt
        from this cycle's proposals: entries with a -inf score are dropped, old functions
        are only kept when add_old_functions_to_evaluation is set, and the
        num_fittest_functions best remaining entries are kept, sorted from best to worst.

        Args:
            proposed_functions: the evaluated functions
            is_new_function: for each function, whether it was newly generated this cycle
            scores: score of each function (-inf marks a non-executable function)
            func_notes: notes dictionary of each function
        '''
        assert len(proposed_functions) == len(is_new_function) == len(scores) == len(func_notes)

        # nothing to store (and nothing to log) for an empty batch
        if len(proposed_functions) == 0:
            return

        neg_inf = float('-inf')

        # log the following things: self.num_evolutions, proportion of non-inf scores, best score
        non_inf_scores = sum(score != neg_inf for score in scores)
        proportion_non_inf_scores = non_inf_scores / len(scores)
        self.logger.info(f'Evolution: {self.num_evolutions}, Number of generated functions: {len(proposed_functions)}, Proportion of executable functions: {proportion_non_inf_scores}, Best score: {max(scores)}, Proportion of new functions: {sum(is_new_function) / len(is_new_function)}')

        # only new, executable functions enter the dictionary
        for function, notes, score, fresh in zip(proposed_functions, func_notes, scores, is_new_function):
            if score == neg_inf or not fresh:
                continue
            self.functions_dict.add_or_update_key(function, notes, score)

        if not self.select_from_last_generation_only:
            return

        # rebuild the beam from this cycle's executable proposals
        survivors = [
            (function, notes, score)
            for function, notes, score, fresh in zip(proposed_functions, func_notes, scores, is_new_function)
            if score != neg_inf and (self.add_old_functions_to_evaluation or fresh)
        ]
        survivors.sort(key=lambda entry: entry[2], reverse=True)
        self.auxilary_functions_list = survivors[:self.num_fittest_functions]


    def evolve(self, num_cycles: int):
        '''
        Runs evolve_once num_cycles times in a row

        Args:
            num_cycles: number of evolution cycles to conduct
        '''
        cycles = range(num_cycles)
        for _cycle in cycles:
            self.evolve_once()

    def generate_function(self, prompt: str, tries: int = 8) -> Optional[str]:
        '''
        Generates a function given a prompt

        The model output is parsed with parse_function and validated with check_function.
        A candidate is accepted only if check_function returns a truthy value. If
        check_function raises or returns a falsy value, the error is fed back to the
        model through gen_execution_error_feedback and a new candidate is requested.

        Args:
            prompt: prompt asking the model for a function
            tries: maximum number of candidates to check

        Returns:
            the first candidate accepted by check_function, or None if none of the
            `tries` candidates was accepted

        TODO: generalize this to different temperature, multi-generate
        '''
        new_function = self.parse_function(self.model.generate(prompt, 1)[0])

        for attempt in range(tries):
            try:
                if self.check_function(new_function):
                    return new_function
                # a falsy result is a rejection, consistent with add_seed_functions
                error = ValueError('check_function returned False for the generated function')
            except Exception as e:
                error = e

            self.logger.info(f'Error: \n {error} \n while executing function: \n --- \n  {new_function}')

            # do not spend a model call on a candidate that would never be checked
            if attempt == tries - 1:
                break

            new_prompt = self.prompt_generator.gen_execution_error_feedback(prompt, new_function, error)
            new_function = self.parse_function(self.model.generate(new_prompt, 1)[0])

        return None
    
    @staticmethod
    def store_items_to_csv(items: list[dict], filename: str):
        '''
        Stores the items to a csv file
        '''
        df = pd.DataFrame(items)
        df.to_csv(filename)

    @staticmethod
    def get_items_from_csv(filename: str) -> list[dict]:
        '''
        Gets the items from a csv file
        '''
        df = pd.read_csv(filename)
        return df.to_dict(orient='records')

    def produce_analysis(self, k: int = -1, evaluator: Optional[Evaluator] = None, save_directory: str = 'outputs/', use_estimated_as_final: bool = False):
        '''
        Produces an analysis of the k fittest functions and saves the results and benchmark scores to a directory.

        Writes 'results.csv' (one row per analyzed function, sorted by final score from
        best to worst) and 'benchmark_scores.csv' (columns 'Benchmark' and 'Score') into
        save_directory, which is created if it does not exist.

        Args:
            k: Number of fittest functions to analyze; passed to get_top_k_items
               (the default -1 presumably selects all functions -- TODO confirm).
            evaluator: Evaluator used for the final evaluation with benchmark. Defaults to self.evaluator.
            save_directory: Directory to save the analysis results.
            use_estimated_as_final: Whether to use the estimated score as the final score
                (no evaluation is run and no benchmark scores are produced in that case).

        Returns:
            None
        '''

        # figure out num_calls, num_output_tokens, and num_total_tokens
        num_calls = self.model.get_num_calls()
        num_output_tokens = self.model.get_num_output_tokens()
        num_total_tokens = self.model.get_num_total_tokens()

        if evaluator is None:
            evaluator = self.evaluator
        fittest_items = self.functions_dict.get_top_k_items(k)

        # Filter out the functions that are not executable
        fittest_items = [(func, info, priority) for func, info, priority in fittest_items if priority != -float('inf')]

        if not use_estimated_as_final:
            # Use the selected evaluator (not unconditionally self.evaluator) to
            # evaluate the fittest functions with benchmark
            functions = [func for func, _, _ in fittest_items]
            function_scores, _function_notes, benchmark_scores = evaluator.evaluate_with_benchmark(functions)
        else:
            # Use the estimated scores as the final scores
            function_scores = [priority for _, _, priority in fittest_items]
            benchmark_scores = {}

        # Store the results in a list of dictionaries; keys of the info dictionary take precedence
        results = []
        for i, (func, info, priority) in enumerate(fittest_items):
            to_append =  {'final_score': function_scores[i], 'estimated_score': priority, 'num_calls': num_calls, 'num_output_tokens': num_output_tokens, 'num_total_tokens': num_total_tokens, 'function': func,} | info
            results.append(to_append)

        # Sort results by final score
        results = sorted(results, key=lambda x: x['final_score'], reverse=True)

        # Log the estimated scores, final score, generation, and iteration for each function
        for info in results:
            self.logger.info(f'Estimated Score: {info["estimated_score"]}, Final Score: {info["final_score"]}, Iteration: {info["iteration"]}, Generation: {info["generation"]}')

        # Log the benchmark scores
        for benchmark_name, benchmark_score in benchmark_scores.items():
            self.logger.info(f'Benchmark {benchmark_name} score: {benchmark_score}')

        # Log total number of functions
        self.logger.info(f'Total number of functions: {len(results)}')

        # Make sure the output directory exists before writing into it
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)

        # Save results and benchmark scores to the save directory
        results_filename = os.path.join(save_directory, 'results.csv')
        benchmark_filename = os.path.join(save_directory, 'benchmark_scores.csv')

        pd.DataFrame(results).to_csv(results_filename, index=False)
        pd.DataFrame(list(benchmark_scores.items()), columns=['Benchmark', 'Score']).to_csv(benchmark_filename, index=False)

    @staticmethod
    def produce_figures(save_directory: str = 'outputs/'):
        '''
        Produces figures from the results stored in the save directory.

        Reads 'results.csv' and 'benchmark_scores.csv' from save_directory and writes two
        HTML scatter plots (generation vs final score, generation vs estimated score) with
        a timestamp in the file name. Each function is connected to its predecessor by a
        line; a function whose predecessor is not present in results.csv is plotted
        without a connecting line.

        Args:
            save_directory: Directory where the results are stored.

        Returns:
            None
        '''
        # Retrieve the results and benchmark scores from the save directory
        results_filename = os.path.join(save_directory, 'results.csv')
        benchmark_filename = os.path.join(save_directory, 'benchmark_scores.csv')

        df = pd.read_csv(results_filename)
        benchmark_scores = pd.read_csv(benchmark_filename).set_index('Benchmark').to_dict()['Score']

        # Sort the DataFrame by 'generation' for logical sequencing (if not already sorted)
        df.sort_values('generation', inplace=True)

        def add_lineage_lines(fig, score_column):
            # Draws a line from every function to its predecessor on the given figure
            for _, row in df.iterrows():
                if pd.isna(row['predecessor_function']):  # seed functions have no predecessor
                    continue
                predecessors = df[df['function'] == row['predecessor_function']]
                if predecessors.empty:
                    # the predecessor is not part of the results, so there is no point to connect to
                    continue
                predecessor_row = predecessors.iloc[0]
                fig.add_shape(type='line',
                            x0=predecessor_row['generation'], y0=predecessor_row[score_column],
                            x1=row['generation'], y1=row[score_column],
                            line=dict(color='RoyalBlue', width=1),
                            )

        # Create a scatter plot of the generation (x-axis) vs the final score (y-axis)
        # Add benchmark scores as horizontal lines
        fig3 = px.scatter(df, x='generation', y='final_score', hover_data=['iteration'], title='Generation vs Final Score')
        for benchmark_name, benchmark_score in benchmark_scores.items():
            fig3.add_hline(y=benchmark_score, line_dash='dash', annotation_text=f'{benchmark_name} benchmark', annotation_position='top right')
        # Change x-axis label to 'Generation' and y-axis label to 'Final Score'
        fig3.update_layout(xaxis_title='Generation', yaxis_title='Final Score')
        add_lineage_lines(fig3, 'final_score')

        # Create a scatter plot of the generation (x-axis) vs the estimated score (y-axis)
        fig4 = px.scatter(df, x='generation', y='estimated_score', hover_data=['iteration'], title='Generation vs Estimated Score')
        # Change x-axis label to 'Generation' and y-axis label to 'Estimated Score'
        fig4.update_layout(xaxis_title='Generation', yaxis_title='Estimated Score')
        add_lineage_lines(fig4, 'estimated_score')

        # Save the figures to the save directory
        date_name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        fig3.write_html(os.path.join(save_directory, f'generation_vs_final_score_{date_name}.html'))
        fig4.write_html(os.path.join(save_directory, f'generation_vs_estimated_score_{date_name}.html'))

        

class ThoughtBeamEvolver(BeamEvolver):
    '''
    Conducts beam evolution with thought

    Compared to the base evolver, every function carries an 'abstract' (thought) in its
    notes, and the abstract is revised before the improved function is generated.
    '''

    def evolve_once(self):
        '''
        Conducts one cycle of evolution

        For every parent: draw conclusions from its feedback, revise its abstract with
        those conclusions, then generate a new function from the revised abstract.
        Passes over the parents repeat until the number of attempts reaches batch_size
        (the last pass is always completed, so it can overshoot).
        '''

        self.num_evolutions += 1

        parents = self.get_fittest(self.num_fittest_functions)
        if len(parents) == 0:
            self.logger.info('No functions in the library')
            return

        proposed_functions, is_new_function, func_notes = [], [], []
        attempts = 0

        while attempts < self.batch_size:
            for parent_func, parent_info, _priority in parents:
                abstract = parent_info['abstract']
                feedback = parent_info['feedback']
                generation = parent_info['generation']

                # first draw conclusions from feedback
                processed_feedback = self.analyzer.translate(feedback)
                conclusions_prompt = self.prompt_generator.gen_draw_conclusions_from_feedback_prompt(parent_func, processed_feedback)
                conclusions = self.model.generate(conclusions_prompt, 1)[0]

                # next modify previous abstract with conclusions
                thought_prompt = self.prompt_generator.gen_improvement_thought_prompt(conclusions_prompt + conclusions, abstract)
                new_abstract = self.model.generate(thought_prompt, 1)[0]

                # generate the new function
                function_prompt = self.prompt_generator.gen_improved_function_prompt(thought_prompt + new_abstract)
                child = self.generate_function(function_prompt)

                if child is not None:
                    is_new_function.append(True)
                    proposed_functions.append(child)
                    func_notes.append({'iteration': self.num_evolutions, 'generation': generation + 1, 'predecessor_function': parent_func, 'abstract': new_abstract})
                elif self.forward_fill:
                    # generation failed: re-propose the parent. The parent's own notes
                    # dictionary is reused (not copied).
                    is_new_function.append(False)
                    proposed_functions.append(parent_func)
                    func_notes.append(parent_info)

                attempts += 1

        # if evaluate_old_functions is True, also add all fittest functions to proposed functions
        if self.evaluate_old_functions:
            for parent_func, parent_info, _priority in parents:
                proposed_functions.append(parent_func)
                func_notes.append(parent_info)
                is_new_function.append(False)

        # evaluate the proposed functions
        scores, notes = self.evaluate(proposed_functions)

        # attach the feedback to the notes; for shared parent dictionaries this
        # overwrites the parent's stored 'feedback' in place
        for idx, note in enumerate(func_notes):
            note['feedback'] = notes[idx]

        # now store the improved functions
        self.store_improved_functions(proposed_functions, is_new_function, scores, func_notes)

class ImprovementLibraryEvolver(BeamEvolver):
    '''
    Conducts evolution with a scored library (bandit learner) of improvement ideas.

    Builds upon BeamEvolver with an additional library of improvement ideas. Each evolution step
    first generates ideas from the feedback of the fittest functions and registers them with the
    bandit learner, then samples ideas from the learner to propose new functions, and finally
    updates each sampled idea with the score change its new function achieved over its predecessor.
    '''
    # bandit learner that stores the improvement ideas (arms) together with their notes and scores
    mbleaner: MultiarmedBanditLearner
    # evaluator used to score proposed functions
    evaluator: Evaluator
    # LLM used to draw conclusions from feedback and to generate ideas and functions
    model: LLMModel
    # number of function proposals targeted per evolution step
    batch_size: int
    # translates raw evaluation feedback into the form placed in prompts
    analyzer: FeedbackAnalyzer
    # builds the prompts sent to the model and parses idea lists out of its replies
    prompt_generator: PromptGenerator
    # library of functions kept by the base class, ordered by priority
    functions_dict: UpdatablePriorityDictionary
    # number of evolution steps performed so far (incremented by evolve_once)
    num_evolutions: int


    def __init__(self, evaluator: Evaluator, model: LLMModel, 
                 analyzer: FeedbackAnalyzer, prompt_generator: PromptGenerator, batch_size: int = 10,
                 mbleaner: Optional[MultiarmedBanditLearner] = None,
                 seed_functions: Optional[list[tuple[str, dict]]] = None, check_function: Callable[[str], bool] = lambda x: True, num_ideas_per_iteration: int = 2, 
                 parse_function: Callable[[str],str] = lambda x: x, num_fittest_functions: int = 1, select_from_last_generation_only: bool = False, forward_fill: bool = False, add_old_functions_to_evaluation: bool = True, evaluate_old_functions: bool = True):
        '''
        Args:
            evaluator: evaluator to use for evaluating functions
            model: LLM model to use for generating functions
            analyzer: feedback analyzer to use for analyzing feedback
            prompt_generator: generator for the prompts sent to the model
            batch_size: number of functions to propose per evolution step
            mbleaner: bandit learner to use for storing improvement ideas; a fresh MultiarmedBanditLearner is created when None
            seed_functions: seed functions to start the evolution
            check_function: function to check if a function is valid
            num_ideas_per_iteration: number of improvement ideas requested per idea-generation prompt; also used (integer divided by num_fittest_functions) to derive how many prompts are issued per fittest function
            parse_function: function to parse a function
            num_fittest_functions: number of fittest functions to consider each iteration
            select_from_last_generation_only: whether to select from the last generation only
            forward_fill: whether to re-propose the original function when no new function could be generated from it
            add_old_functions_to_evaluation: forwarded unchanged to BeamEvolver
            evaluate_old_functions: whether the fittest functions are re-evaluated alongside the new proposals
        '''
        super().__init__(
            evaluator=evaluator,
            model=model,
            analyzer=analyzer,
            batch_size=batch_size,
            seed_functions=seed_functions,
            check_function=check_function,
            prompt_generator=prompt_generator,
            parse_function=parse_function,
            num_fittest_functions=num_fittest_functions,
            select_from_last_generation_only=select_from_last_generation_only,
            forward_fill=forward_fill,
            add_old_functions_to_evaluation=add_old_functions_to_evaluation,
            evaluate_old_functions=evaluate_old_functions,
        )
        self.mbleaner = MultiarmedBanditLearner() if mbleaner is None else mbleaner
        self.num_ideas_per_iteration = num_ideas_per_iteration

        # per-fittest-function budgets: both are the configured totals integer divided by num_fittest_functions
        self.num_implements_per_iteration = self.batch_size // self.num_fittest_functions
        self.num_idea_loops = self.num_ideas_per_iteration // self.num_fittest_functions

        self.logger.info(f'Batch size: {self.batch_size}, Num fittest functions: {self.num_fittest_functions}, Num ideas per iteration: {self.num_ideas_per_iteration}, Num idea loops: {self.num_idea_loops}, Num implements per iteration: {self.num_implements_per_iteration}')

        # the integer divisions above silently drop the remainder, so report uneven configurations
        if self.batch_size % self.num_fittest_functions != 0:
            self.logger.warning(f'Batch size {self.batch_size} is not divisible by num fittest functions {self.num_fittest_functions}')
        if self.num_ideas_per_iteration % self.num_fittest_functions != 0:
            self.logger.warning(f'Num ideas per iteration {self.num_ideas_per_iteration} is not divisible by num fittest functions {self.num_fittest_functions}')

        # with fewer than one idea loop, generate_improvement_ideas issues no prompts at all, so the
        # bandit learner only ever contains ideas that were supplied with it
        if self.num_idea_loops < 1:
            self.logger.warning(f'Num idea loops is {self.num_idea_loops}; no improvement ideas will be generated during evolution')
        

    def generate_improvement_ideas(self, batch_size, num_loops:int =1, num_ideas:int =1, improvement_prior=0.0) -> None:
        '''
        Generates improvement ideas and adds them to the bandit learner

        This is basically a reflection step where the agent reflects on the feedback and generates improvement ideas.

        We get the top batch_size functions from our function library along with their stored feedback. For each
        function, num_loops times, the feedback is passed through the feedback analyzer, the LLM is asked to draw
        conclusions from it, and then asked for num_ideas improvement ideas based on those conclusions. Every
        parsed idea is added to the bandit learner together with the conclusion it was derived from.

        Args:
            batch_size: number of functions to sample from the function library
            num_loops: number of times to repeat the process per function
            num_ideas: number of improvement ideas to generate per prompt
            improvement_prior: currently not forwarded to the bandit learner; ideas are registered with a score of None
        '''
        # get the top batch_size functions from the function library
        top_items = self.get_fittest(batch_size)

        # extract every (function, feedback) pair up front so a missing 'feedback' entry fails before any LLM call
        targets = [(function, info['feedback']) for function, info, _score in top_items]

        # each idea is kept next to the conclusion that produced it; one prompt can yield several ideas,
        # so ideas and conclusions must not be tracked as two independent lists
        ideas_with_conclusions = []

        for function, feedback in targets:
            for _ in range(num_loops):
                # sample and translate the feedback
                processed_feedback = self.analyzer.translate(feedback)

                # draw conclusions from the feedback
                conclusion_prompt = self.prompt_generator.gen_draw_conclusions_from_feedback_prompt(function, processed_feedback)
                conclusions = self.model.generate(conclusion_prompt, 1)[0]

                # generate improvement ideas, building on the conclusion prompt and its answer
                idea_prompt = self.prompt_generator.gen_specific_improvement_prompt(conclusion_prompt, conclusions, num_ideas=num_ideas)
                improvement_ideas_str = self.model.generate(idea_prompt, 1)[0]

                # parse out ideas from the improvement ideas string
                parsed_ideas = self.prompt_generator.parse_improvement_ideas(improvement_ideas_str, num_ideas)
                ideas_with_conclusions.extend((idea, conclusions) for idea in parsed_ideas)

        # add the improvement ideas to the bandit learner; None is passed as the score, as no implementation has been evaluated yet
        for idea, conclusion in ideas_with_conclusions:
            self.mbleaner.add_or_update(idea, None, {'feedback_conclusion': conclusion, 'iteration': self.num_evolutions, 'num_implements': 0})
    
    def implement_and_evaluate(self, num_ideas: int, num_fittest_functions: int = 1) -> None:
        '''
        Samples ideas from the bandit learner, applies them to the fittest functions, and evaluates the results.

        Cycles over the num_fittest_functions fittest functions until at least self.batch_size generation attempts
        have been made. For every attempt one idea is softmax-sampled from the bandit learner and the LLM is asked
        to implement it on the function. All proposals are evaluated and stored in the function library, and each
        sampled idea is updated with the score difference between its new function and the function it was applied to.

        Args:
            num_ideas: currently unused; the number of generation attempts is governed by self.batch_size
            num_fittest_functions: number of fittest functions to improve upon
        '''

        # get num_fittest_functions fittest functions
        fittest_items = self.get_fittest(num_fittest_functions)

        if len(fittest_items) == 0:
            self.logger.info('No functions in the library')
            return

        # propose improvements for the fittest functions
        proposed_functions = []
        is_new_function = []
        func_notes = []
        # index-aligned with proposed_functions: (idea, predecessor function) for generated proposals and
        # None for re-used old functions, so identical function texts cannot overwrite each other's attribution
        proposal_origins = []
        old_function_to_score = {func: score for func, _info, score in fittest_items}

        counter = 0 # counts the number of generation attempts
        while counter < self.batch_size:
            for function, info, _prev_score in fittest_items:

                # first sample an improvement idea from the bandit learner
                idea_, _idea_notes, _idea_score = self.mbleaner.softmax_sample()

                # generate the new function
                prompt = self.prompt_generator.gen_implement_function_from_improvement_prompt(function, idea_)
                new_func = self.generate_function(prompt)

                if new_func is None:
                    # no usable function was generated; optionally re-propose the old function. TODO: this is a hack
                    if self.forward_fill:
                        proposed_functions.append(function)
                        # the old function is not a new one and must not be credited to the sampled idea
                        is_new_function.append(False)
                        func_notes.append(info)
                        proposal_origins.append(None)
                        self.logger.info(f"Generated function is not executable. Forward filling with old function {function}.")
                else:
                    proposed_functions.append(new_func)
                    is_new_function.append(True)
                    func_notes.append({'abstract': '', 'iteration': self.num_evolutions, 'generation': info['generation'] + 1, 'idea_trace': info['idea_trace'] + [idea_], 'predecessor_function': function})
                    proposal_origins.append((idea_, function))
                counter += 1

        # if evaluate_old_functions is True, also add all fittest functions to proposed functions so they are re-evaluated
        if self.evaluate_old_functions:
            for func, info, _priority in fittest_items:
                proposed_functions.append(func)
                func_notes.append(info)
                is_new_function.append(False)
                proposal_origins.append(None)

        # evaluate the proposed functions
        scores, notes = self.evaluate(proposed_functions)

        # sanity check: scores is expected to be a list of floats
        assert all(isinstance(score, float) for score in scores)

        # add notes to func_notes
        for i, note in enumerate(func_notes):
            note['feedback'] = notes[i]

        # add new functions to the function library
        self.store_improved_functions(proposed_functions, is_new_function, scores, func_notes)

        # log scores
        self.logger.debug('new_function scores: %s', scores)

        # if evaluate_old_functions is True, the fittest functions were appended last, so the tail of scores
        # holds their fresh scores; use those as the baseline instead of the stored ones
        if self.evaluate_old_functions:
            num_old = len(fittest_items)
            for offset, (function, _info, _score) in enumerate(fittest_items):
                old_function_to_score[function] = scores[offset - num_old]

        # update idea score for each generated function
        for i, origin in enumerate(proposal_origins):

            # pass if the function is not new or is not executable
            if origin is None or not is_new_function[i] or scores[i] == float('-inf'):
                continue

            idea, old_function = origin
            improvement_score = scores[i] - old_function_to_score[old_function]

            # get the idea notes
            idea_notes = self.mbleaner.get_notes_for_arm(idea)

            # increment the number of implementations of the idea by 1
            idea_notes['num_implements'] += 1

            # update the score of the idea
            self.mbleaner.add_or_update(idea, improvement_score, idea_notes)

            self.logger.info(f"Idea {idea} implemented with average improvement score {improvement_score}")
        

    def evolve_once(self) -> None:
        '''
        Performs one evolution step: bumps the evolution counter, generates improvement ideas
        from the fittest functions, then implements sampled ideas and evaluates the results.
        '''
        self.num_evolutions += 1

        # reflection phase: turn feedback on the fittest functions into new ideas
        reflection_kwargs = {
            'batch_size': self.num_fittest_functions,
            'num_loops': self.num_idea_loops,
            'num_ideas': self.num_ideas_per_iteration,
        }
        self.generate_improvement_ideas(**reflection_kwargs)

        # implementation phase: apply sampled ideas and score the proposals
        implementation_kwargs = {
            'num_fittest_functions': self.num_fittest_functions,
            'num_ideas': self.num_implements_per_iteration,
        }
        self.implement_and_evaluate(**implementation_kwargs)

    def produce_analysis(self, k: int = -1, evaluator: Optional[Evaluator] = None, save_directory: str = 'outputs/', use_estimated_as_final: bool = False) -> None:
        '''
        Runs the base-class analysis, then logs the top k ideas and writes them to idea_results.csv.

        Each CSV row is the idea's notes extended with the idea text ('idea'), its value estimate
        from the bandit learner ('score'), and the score attached to the item as returned by
        get_top_k_items ('ucb').

        Args:
            k: number of functions / ideas to analyze
            evaluator: evaluator to use for evaluating functions for final analysis
            save_directory: directory to save the analysis results
            use_estimated_as_final: forwarded unchanged to the base-class analysis
        '''
        super().produce_analysis(k=k, evaluator=evaluator, save_directory=save_directory, use_estimated_as_final=use_estimated_as_final)

        # one record per top idea; the list of records becomes the dataframe written below
        idea_results = [
            notes | {'idea': arm, 'score': self.mbleaner.get_value_estimate_for_arm(arm), 'ucb': priority}
            for arm, notes, priority in self.mbleaner.get_top_k_items(k)
        ]

        # report the ideas and how many there are
        self.logger.info(f"Top {k} ideas:")
        for record in idea_results:
            self.logger.info(f"Idea: {record['idea']}, Score: {record['score']}")
        self.logger.info(f'Total number of ideas: {len(idea_results)}')

        # persist the records as CSV inside save_directory
        pd.DataFrame(idea_results).to_csv(os.path.join(save_directory, 'idea_results.csv'), index=False)

    @staticmethod
    def produce_figures(save_directory: str = 'outputs/') -> None:
        '''
        Produces the base-class figures, then a boxplot of the idea scores read from
        idea_results.csv in save_directory, saved as idea_scores_boxplot.html.

        Args:
            save_directory: directory holding idea_results.csv and receiving the figure
        '''
        BeamEvolver.produce_figures(save_directory=save_directory)

        # read back the idea table written by produce_analysis
        idea_scores = pd.read_csv(os.path.join(save_directory, 'idea_results.csv'))

        # box plot over the 'score' column, written out as a standalone HTML file
        boxplot = px.box(idea_scores, y='score', title='Idea Scores')
        boxplot.write_html(os.path.join(save_directory, 'idea_scores_boxplot.html'))

        # self.logger.info(f"Boxplot of idea scores saved to {fig_filename}")


