from agents.base_agent import BaseAgent_openai
from prompts.code_agent_prompts import *
from utils.extractors import *
from utils.device_info import get_system_info
import os


class CodeAgentAPI(BaseAgent_openai):
    def __init__(self, config, agent_name, debug_logger, main_logger, sub_logger, checker, is_higher_better):
        super().__init__(config, agent_name, debug_logger, main_logger, sub_logger, checker)
        self.data_dir_path = self.config['save_path']
        self.result_dir_path = self.config['path_log']
        self.background_data_path = self.config['background_data_path']
        self.is_higher_better = is_higher_better

    def implement_task(self, previous_code: list | None, task_name: str,
                       idea: str, eda_output: str | None, eda_images: list | None,
                       node_index: int | str, parent_index: int | str) -> str:
        """
        implement the idea in Python

        :param previous_code: code of ideas preceding the implementation of this idea in this branch
        :param task_name: stage name
        :param idea: description of the idea to be implemented
        :param eda_output: text output from the EDA stage
        :param eda_images: array of images in base64 format
        :param node_index: index of nodes for generating file names to make them unique
        :param parent_index: index of parents nodes for generating file names to make them unique
        :return: generated code
        """
        with open(self.background_data_path, 'r') as f:
            background_info = '\n'.join(f.readlines())

        self.clear_context()
        match task_name:
            case "Model training":
                self.instructions = MODELING_TASK_PROMPT.format(
                    background_info=background_info,
                    previous_code=previous_code,
                    data_dir_path=self.result_dir_path,
                    idea=idea,
                    eda_output=eda_output,
                    node_index=node_index,
                    parent_index=parent_index,
                    device_info=get_system_info(),
                    sample_submission_path=os.path.join(self.config['path_log'], "code", "data",
                                                        "sample_submission.csv")
                )
            case "Data preparation and feature engineering":
                self.instructions = FE_TASK_PROMPT.format(
                    background_info=background_info,
                    data_dir_path=self.data_dir_path,
                    result_dir_path=self.result_dir_path,
                    idea=idea,
                    eda_output=eda_output,
                    node_index=node_index,
                    device_info=get_system_info()
                )
        code = self.code_generation_from_several_attempts(
            number_of_attempts=self.config['number_of_attempts_coder'],
            need_extract_code=True,
            input_text='',
            task_name=task_name,
            images=eda_images,
            node_index=node_index,
            parent_index=parent_index,
            idea=idea
        )
        return code

    def implement_eda(self, ideas, result_dir_path):
        """
        Implement EDA in a single python code

        :param ideas: ideas that need to be implemented
        :param result_dir_path: path to the folder where you want to save the images obtained when running the code
        :return: code/None, error/None
        """
        with open(self.background_data_path, 'r') as f:
            background_info = '\n'.join(f.readlines())

        self.clear_context()
        self.instructions = EDA_TASK_PROMPT.format(
            background_info=background_info,
            data_dir_path=self.data_dir_path,
            result_dir_path=result_dir_path,
            idea=ideas
        )
        code = self.generate_response('')
        return extract_python_code(code)

    def merge_insights(self, data_from_the_first_idea,
                       data_from_the_second_idea, current_task: str,
                       node_index: int, parent_index: int | None
                       ) -> str:
        self.clear_context()

        idea1, code1, score1 = data_from_the_first_idea
        idea2, code2, score2 = data_from_the_second_idea

        if self.is_higher_better:
            score_mode_text = "A higher score indicates higher quality"
        else:
            score_mode_text = "A lower score indicates higher quality"
        match current_task:
            case "Model training":
                self.instructions = MERGE_IDEAS_CODE_MODELING.format(
                    idea1=idea1, idea2=idea2,
                    code1=code1, code2=code2,
                    score1=score1, score2=score2,
                    data_dir_path=self.result_dir_path,
                    device_info=get_system_info(),
                    node_index=node_index,
                    parent_index=parent_index,
                    score_mode=score_mode_text
                )
            case "Data preparation and feature engineering":
                self.instructions = MERGE_IDEAS_CODE_FE.format(
                    idea1=idea1, idea2=idea2,
                    code1=code1, code2=code2,
                    score1=score1, score2=score2,
                    data_dir_path=self.data_dir_path,
                    result_dir_path=self.result_dir_path,
                    device_info=get_system_info(),
                    node_index=node_index,
                    score_mode=score_mode_text
                )
        result = self.generate_response('')
        return extract_python_code(result)

    def format_code(self, code: str) -> str:
        """
        Format for easier reading (move all imports to the beginning, structure, etc.)

        :param code: code that needs to be formatted
        :return:
        """
        self.clear_context()
        self.instructions = FORMAT_CODE.format(code=code)
        result = extract_python_code(self.generate_response(''))
        return result

    def replace_paths(self, code: str, train_csv_path: str, test_csv_path: str, output_dir: str,
                      submission_filename: str, base_dir: str) -> str:
        """
        Replace paths for reading and saving data with specified ones

        :param code: The code in which the replacement must be made
        :param train_csv_path: path to the training sample
        :param test_csv_path: path to the test sample
        :param output_dir: path to the output folder
        :param submission_filename: submission name
        :param base_dir: the base directory from which all files will be taken
        :return: modified code
        """
        self.clear_context()
        self.instructions = REPLACE_DIR_FOR_TEST_PROMPT.format(
            train_csv_path=train_csv_path,
            test_csv_path=test_csv_path,
            output_dir=output_dir,
            submission_filename=submission_filename,
            base_dir=base_dir
        )
        result = self.generate_response(f"# Make changes to the paths to the data in this code:\n\n{code}")
        return result

    def final_processing(self, code: str, train_csv_path: str, test_csv_path: str, output_dir: str,
                         submission_filename: str, base_dir: str, idea: str):
        self.clear_context()
        self.instructions = FINAL_PREPROCESSING.format(
            train_csv_path=train_csv_path,
            test_csv_path=test_csv_path,
            output_dir=output_dir,
            submission_filename=submission_filename,
            base_dir=base_dir,
            idea=idea,
            device_info=get_system_info()
        )
        result = self.generate_response(f"# Make changes to the paths to the data in this code:\n\n{code}")
        return result

    def speed_up_code(self, code: str, is_model_training_stage: bool):
        """
        Change the code so that it runs faster (reduce the number of epochs, remove hyperparameter selection).
        Necessary for quick debugging

        :param code: code that needs to be accelerated
        :return: json with accelerated code and the number of epochs that were in the original code
        """
        for _ in range(3):
            self.clear_context()
            if is_model_training_stage:
                self.instructions = SPEED_UP_CODE_PROMPT_MT.format(code=code)
            else:
                self.instructions = SPEED_UP_CODE_PROMPT_FE.format(code=code)
            result = self.generate_response('')
            result_json = extract_json(result)
            result_code = extract_python_code(result)

            if result_json is not None and 'original_code_description' in result_json and \
                    'changes_made' in result_json and result_code is not None:
                return result_code, result_json['original_code_description'], result_json["changes_made"]
        return None, None, None

    def return_code_speed(self, code: str, original_code_description, changes_made,
                          is_model_training_stage: bool = True) -> str:
        """
        Restore the code's original speed by changing the number of epochs
        in the accelerated code to the original number

        :param code: Accelerated code
        :param epochs: Number of epochs in the source code
        :return: code with corrected number of epochs
        """
        self.clear_context()
        if is_model_training_stage:
            self.instructions = RETURN_CODE_SPEED_MT.format(code=code,
                                                            original_code_description=original_code_description,
                                                            changes_made=changes_made, devices_info=get_system_info())
        else:
            self.instructions = RETURN_CODE_SPEED_FE.format(code=code,
                                                            original_code_description=original_code_description,
                                                            changes_made=changes_made, devices_info=get_system_info())
        result = self.generate_response('')
        return result
