from agents.base_agent import BaseAgent_openai
from agents.debug_agent import DebugAgentAPI
from agents.code_agent import CodeAgentAPI
from prompts.estimate_agent_prompts import *
import os
from utils.extractors import *


class EstimateAgent(BaseAgent_openai):
    """
    Agent that predicts the validation score of model-training code.

    The pipeline in ``predict_score_and_debug`` is:
    speed the code up (via the code agent),
    debug the sped-up version (via the debug agent),
    restore the original speed, and finally
    ask the LLM to predict the score from a description of the code and its debug-run output.
    """

    def __init__(self, config, agent_name, main_logger, sub_logger, debug_logger,
                 debugger: DebugAgentAPI, coder: CodeAgentAPI, background_data):
        """
        :param config: agent configuration; ``config['path_log']`` is used as the root directory
            for submission files produced while debugging
        :param agent_name: name of this agent (forwarded to the base agent)
        :param main_logger: logger used here to report score-extraction errors
        :param sub_logger: logger used here for step-by-step progress messages
        :param debug_logger: logger forwarded to the base agent
        :param debugger: agent used to run and fix the sped-up code
        :param coder: agent used to speed the code up and to restore its original speed
        :param background_data: task description inserted into the score-prediction prompt
        """
        # NOTE(review): loggers are forwarded as (debug, main, sub), which differs from the order this
        # constructor receives them in; presumably this matches BaseAgent_openai's signature — confirm.
        super().__init__(config, agent_name, debug_logger, main_logger, sub_logger)
        self.debug_agent = debugger
        self.coder = coder
        self.background_data = background_data

    def generate_description(self, code):
        """
        Generates a description for the code to make it easier to predict the score.

        :param code: the code for which we generate the description
        :return: the response of ``generate_response()`` for the description prompt
        """
        self.clear_context()

        # Set up the instruction and generate the answer.
        # Presumably generate_response() (base class) sends self.instructions as the prompt — confirm.
        self.instructions = GENERATE_DESRIPTION_PROMPT.format(code=code)
        description = self.generate_response()

        return description

    def predict_score_and_debug(self, examples_other: list, example_this: dict, description: str,
                                code: str, node_index: str | int, new_filename: str, pre_path: str):
        """
        Predicts the validation score for the model training code
        and debugs it on a sped-up version (with one epoch/iteration, etc.).

        :param examples_other: description examples with score trained with another feature engineering;
            if None, the prompt without anchor examples is used
        :param example_this: description example with score trained on this dataset version;
            NOTE: its ``'code_output'`` key is overwritten with the debug-run output of ``code``
        :param description: description for actual training code
        :param code: code for which we predict score
        :param node_index: index of the node in the tree that is linked to the current code
        :param new_filename: filename for saving current code
        :param pre_path: prefix to the file storage path
        :return: ``(code, score)``, where ``code`` is the debugged code with its original speed
            restored and ``score`` is the predicted score as a float. Returns ``(None, None)`` if
            speeding up, debugging, or score extraction fails.
        """
        self.clear_context()
        self.sub_logger.info("Predict score...")

        # 1. Speed up the code by reducing the number of training epochs/iterations/etc. to 1
        self.sub_logger.info("\t1. Speed up code...")
        modified_code, original_code_description, changes_made = self.coder.speed_up_code(
            code, is_model_training_stage=True
        )
        if modified_code is None:
            return None, None

        # 2. Debug the sped-up code. The score reported by the debug run is not used:
        #    a one-epoch run is not representative, so the score is predicted by the LLM instead.
        task_filepath = self.save_code_to_file(modified_code, new_filename, pre_path)
        self.sub_logger.info("\t2. Debug...")
        debug_code, code_output, _ = self.debug_agent.run_code(
            task_filepath,
            submission_path=os.path.join(self.config['path_log'], pre_path, f"my_submission_{node_index}.csv"),
            needs_invalid=True,
            debug_normal_speed=False
        )

        # Recorded before the failure check, so the caller's dict is updated even if debugging failed.
        example_this['code_output'] = code_output

        # If the code cannot be debugged, give up on this candidate
        if debug_code is None:
            return None, None

        # 3. Restore the original number of epochs/iterations in the debugged code
        self.sub_logger.info("\t3. Return speed...")
        code = extract_python_code(
            self.coder.return_code_speed(debug_code, original_code_description, changes_made)
        )
        self.save_code_to_file(code, new_filename, pre_path)

        # 4. Ask the LLM to predict the score
        self.sub_logger.info("\t4. Predict...")
        self.instructions = self._build_score_instructions(
            examples_other, example_this, description, code_output
        )
        try:
            json_with_score = self.generate_response()
            score = float(extract_json(json_with_score)['score'])
        except Exception as err:
            # Deliberately best-effort: a failed LLM call or an unparsable/missing score is
            # logged and reported as "no prediction" instead of being raised to the caller.
            self.main_logger.error(f'Error while extracting score: {err}')
            return None, None
        return code, score

    def _build_score_instructions(self, examples_other, example_this, description, code_output):
        """Build the score-prediction prompt, including anchor examples when ``examples_other`` is not None."""
        if examples_other is None:
            return PREDICT_SCORE_PROMPT_WITHOUT_ANCHOR_EXAMPLES.format(
                test_case=description,
                test_code_output=code_output,
                task_description=self.background_data
            )
        return PREDICT_SCORE_PROMPT_WITH_ANCHOR_EXAMPLES.format(
            examples_other=examples_other,
            example_this=example_this,
            test_case=description,
            test_code_output=code_output,
            task_description=self.background_data
        )
