from agents.base_agent import BaseAgent_openai
from agents.code_agent import CodeAgentAPI
import pandas as pd
from prompts.debug_agent_prompts import *
from utils.extractors import *
from utils.runners import execute_code
import time
import os


class DebugAgentAPI(BaseAgent_openai):
    def __init__(self, config, agent_name, scorer, main_logger, sub_logger, debug_logger, checker):
        super().__init__(
            config, agent_name, main_logger=main_logger, sub_logger=sub_logger, debug_logger=debug_logger,
            checker=checker
        )
        self.scorer = scorer
        self.coder: CodeAgentAPI | None = None
        self.test_submit_file_path = os.path.join(self.config['save_path'], 'test_submit.csv')

    def set_coder(self, coder: CodeAgentAPI):
        self.coder = coder

    @staticmethod
    def _log_debug_metrics(log_file_path: str, is_successful: bool, final_score: float | None,
                           total_submit_attempts: int, total_code_debug_attempts: int,
                           max_repeated_errors: int, final_error_msg: str | None, debug_mode: str):
        """
        Logs debug metrics to a CSV file.

        :param log_file_path: Path to the CSV log file.
        :param is_successful: True if the debugging attempt was successful, False otherwise.
        :param final_score: The final score if applicable, -1 otherwise.
        :param total_submit_attempts: Total number of submit debugging attempts.
        :param total_code_debug_attempts: Total number of code debugging attempts across all submit attempts.
        :param max_repeated_errors: The maximum count of any single identical error encountered.
        :param final_error_msg: The final error message if debugging failed, None if successful.
        :param debug_mode: The debugging mode used (e.g., "three-stage phase", "holistic").
        """
        data = {
            'debug_mode': [debug_mode],
            'is_successful': [is_successful],
            'final_score': [final_score],
            'total_submit_attempts': [total_submit_attempts + 1],
            'total_code_debug_attempts': [total_code_debug_attempts],
            'max_repeated_errors': [max_repeated_errors],
            'final_error_message': [final_error_msg if final_error_msg is not None else 'N/A']
        }
        df = pd.DataFrame(data)

        # Check if the directory exists, create if not
        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

        if not os.path.exists(log_file_path):
            df.to_csv(log_file_path, index=False)
        else:
            df.to_csv(log_file_path, mode='a', header=False, index=False)

    def execute_debug_prompt(self, prompt_type: str, clear_context: bool = False, **params):
        """
        Generate a response from the model based on the prompt

        :param prompt_type: A label indicating which specific prompt can be taken
        :param clear_context: Is it necessary to clear the context?
        :param params: parameters for substitution in the prompt
        :return: result: generated model text
        """

        prompt_mapping = {
            "localize_error": LOCALIZE_PROMPT,
            "localize_invalid": LOCALIZE_PROMPT_INVALID,
            "debug_code": DEBUG_PROMPT,
            "debug_code_invalid": DEBUG_PROMPT_INVALID,
            "merge": MERGE_PROMPT,
            "merge_invalid": MERGE_PROMPT_INVALID,
            "holistic_debug": HOLISTIC_DEBUG_PROMPT,
            "holistic_debug_invalid": HOLISTIC_DEBUG_PROMPT_INVALID
        }
        if clear_context:
            self.clear_context()
        prompt_template = prompt_mapping.get(prompt_type)
        if prompt_template is None:
            raise ValueError(f"Unknown prompt type: {prompt_type}")
        self.instructions = prompt_template.format(**params)
        result = self.generate_response('')

        if prompt_type not in (
        'merge', 'merge_invalid', "debug_code", 'debug_code_invalid', 'holistic_debug', "holistic_debug_invalid"):
            self.sub_logger.info(f"\t\tLLM Response: ...{result[-500:]}...")
        else:
            self.sub_logger.info(f"\t⬛\t{prompt_type}")
        return result

    def run_debug_process_three_stage_phase(self, code: str, code_output: str, error_story: list,
                                            try_number: int) -> str:
        """
        Debug code using a 3-stage debugging system

        :param code: Code in which the error was triggered
        :param code_output: Code output
        :param error_story: History of all errors raised during debugging
        :param try_number: Debug attempt number
        :return: Corrected complete code
        """
        self.clear_context()

        location = extract_python_code(self.execute_debug_prompt(
            "localize_error",
            code=code,
            error=error_story[-1],
            code_output=code_output
        ))
        corrected_code = extract_python_code(self.execute_debug_prompt(
            "debug_code",
            clear_context=True,
            code=code,
            error_messages=error_story,
            try_number=try_number,
            most_relevant_code_snippet=location,
            output_messages=code_output
        )
        )
        merged_code = extract_python_code(self.execute_debug_prompt(
            "merge",
            clear_context=True,
            wrong_code=code,
            most_relevant_code_snippet=location,
            code_snippet_after_correction=corrected_code,
            error=error_story[-1]
        )
        )
        return merged_code

    def run_debug_process_holistic(self, code: str, code_output: str, error_story: list, try_number: int) -> str:
        """
        Debugs code using a holistic, single-prompt approach.

        The LLM receives all context at once (full code, error history, output)
        and returns the complete, corrected code, avoiding intermediate steps
        like localization and merging.

        :param code: The entire code that produced the error.
        :param code_output: The output generated by the code before it crashed.
        :param error_story: A history of all errors encountered during debugging attempts.
        :param try_number: The current debug attempt number.
        :return: The complete, corrected code as a string.
        """
        self.clear_context()

        final_code = None
        for attempt in range(self.config['number_of_attempts_debug_generate']):
            corrected_code_response = self.execute_debug_prompt(
                "holistic_debug",
                clear_context=True,
                code=code,
                error_history=error_story[-1],
                code_output=code_output
            )
            code = extract_python_code(corrected_code_response)
            if code is None:
                self.add_message_to_context("user", INCORRECT_DEBUG_PROMPT)
            else:
                return code
            final_code = code
        return final_code

    def run_debug_process_for_invalid(self, code: str, code_output: str | None,
                                      reason: str, try_number: int, submission_path: str):
        """
        Debug code that generates an incorrect submit using a 3-stage debugging system

        :param code: Code that generates an incorrect submit
        :param code_output: Code output
        :param reason: The reason why the submission is incorrect
        :param try_number: Debug attempt number
        :param submission_path: Path to submission
        :return: Corrected complete code / error, Is the error caused by (bool)
        """
        self.clear_context()

        sample_submission_head = pd.read_csv(self.test_submit_file_path).head()
        try:
            submission_head = pd.read_csv(submission_path).head()
        except FileNotFoundError:
            submission_head = "ERROR! No submission found at path: {submission_path}"

        location = extract_python_code(self.execute_debug_prompt(
            "localize_invalid",
            code=code,
            reason=reason,
            code_output=code_output,
            sample_submission_head=sample_submission_head,
            submission_head=submission_head
        )
        )
        corrected_code = extract_python_code(self.execute_debug_prompt(
            "debug_code_invalid",
            clear_context=True,
            try_number=try_number,
            most_relevant_code_snippet=location,
            code=code,
            error_messages=reason,
            output_messages=code_output
        )
        )
        merged_code = extract_python_code(self.execute_debug_prompt(
            "merge_invalid",
            clear_context=True,
            wrong_code=code,
            most_relevant_code_snippet=location,
            code_snippet_after_correction=corrected_code
        )
        )
        return merged_code

    def run_holistic_debug_process(self, code: str, code_output: str | None,
                                   reason: str, try_number: int, submission_path: str):
        """
        Debugs code that generated an incorrect submission using a single, holistic prompt.

        :param code: The full code that produces an incorrect submission.
        :param code_output: The captured output from the code execution.
        :param reason: The reason why the submission is considered incorrect.
        :param try_number: The current debug attempt number.
        :param submission_path: The path to the generated submission file.
        :return: A tuple containing the corrected complete code and a boolean (always False here).
        """
        self.clear_context()
        sample_submission_head = pd.read_csv(self.test_submit_file_path).head()
        try:
            submission_head = pd.read_csv(submission_path).head()
        except FileNotFoundError:
            submission_head = "ERROR! No submission found at path: {submission_path}"

        final_code = None
        for attempt in range(self.config['number_of_attempts_debug_generate']):
            model_response = self.execute_debug_prompt(
                "holistic_debug_invalid",
                code=code,
                reason=reason,
                code_output=code_output,
                sample_submission_head=sample_submission_head,
                submission_head=submission_head,
                try_number=try_number
            )
            code = extract_python_code(model_response)
            if code is None:
                self.add_message_to_context("user", INCORRECT_DEBUG_PROMPT)
            else:
                return code
            final_code = code
        return final_code

    def generate_and_debug_code(self, filename, pre_path, submission_name,
                                coder_implement_func, code_params: dict, agent_name,
                                needs_invalid: bool = True, debug_speed_mode=None,
                                timeout: int | None = None ):
        number_of_iter_for_code_regeneration = int(self.config["number_of_iter_for_code_regeneration"])
        task_filepath, error, score = None, None, None
        for attempt in range(number_of_iter_for_code_regeneration):
            code = coder_implement_func(**code_params)
            if code is None:
                return None, None, 'Submission file is invalid, cannot fix', None
            self.sub_logger.info(f"✅ {agent_name.title()} successfully generated the code")

            task_filepath = self.save_code_to_file(
                code, filename, pre_path
            )

            debug_code, error, score = self.run_code(
                task_filepath,
                submission_path=os.path.join(self.config['path_log'], "submissions", submission_name),
                needs_invalid=needs_invalid,
                debug_speed_mode=debug_speed_mode,
                timeout=timeout
            )
            if error == "Same type of errors, try regenerating the code" \
                    and attempt != number_of_iter_for_code_regeneration - 1:
                self.sub_logger.info("Same type of errors, try regenerating the code")
                continue
            return task_filepath, debug_code, error, score
        return task_filepath, None, error, None

    def run_code(self, filepath: str, submission_path: str | None, needs_invalid: bool = True,
                 test_submit_file_path: str | None = None, timeout: int | None = None,
                 debug_speed_mode=None, debug_normal_speed=True):

        if timeout is None:
            timeout = self.config['runtime_error_time']
        self.agent_name = "debugger"
        start_work = time.time()
        self.sub_logger.info('⬛\tRun Debugger')
        error_story = []
        install_error_story = []
        error_title_story = []
        number_of_attempts_debug = self.config['number_of_attempts_debug']
        number_of_attempts_debug_submit = self.config["number_of_attempts_debug_submit"]
        number_of_attempts_install = self.config['number_of_attempts_install']
        error, code, original_code_description, changes_made, code_output = None, None, None, None, None
        original_filepath = filepath

        if debug_speed_mode is None:
            speed_mode = self.config['debug_speed_mode']
        else:
            speed_mode = debug_speed_mode

        if test_submit_file_path is None:
            test_submit_file_path = self.test_submit_file_path

        # --- Save and use accelerated code ---
        if speed_mode == 'fast' and debug_normal_speed:
            with open(filepath, 'r') as file:
                code = ''.join(file.readlines())
                self.sub_logger.info("⏫\tSpeed up code")
            code, original_code_description, changes_made = self.coder.speed_up_code(code, needs_invalid)
            if code is None:
                self.sub_logger.info("❌\tFail")
                return None, "The agent was unable to speed up the code", None

            filepath = f"{filepath.strip('.py')}_fast.py"
            self.sub_logger.info("✔\tSpeed up code complete!")
            self.save_code_to_file(code=code, filename=filepath, pre_path='')

        # --- Data logging variables ---
        debug_successful = False
        final_score: float | None = None
        total_code_debug_attempts_made = 0
        max_identical_error_count = 0

        if test_submit_file_path is None:
            test_submit_file_path = self.test_submit_file_path

        # Define the path for the debug log CSV
        debug_log_file_path = os.path.join(self.config['path_debug_log'], 'debug_metrics.csv')

        for submit_debug_attempt in range(number_of_attempts_debug_submit):
            self.sub_logger.info(f'⬛\tSubmit debug attempt: {submit_debug_attempt + 1}/{number_of_attempts_debug_submit}')
            any_mistakes_in_code = True
            total_code_debug_attempts_made += 1

            # For each attempt to debug the code for correct submission generation, we debug it again for errors
            for code_debug_attempt in range(number_of_attempts_debug):
                self.sub_logger.info(f'\t⬛\tCode debug attempt: {code_debug_attempt + 1}/{number_of_attempts_debug}')
                with open(filepath, 'r') as file:
                    code = ''.join(file.readlines())

                # We debug module not found errors separately
                for attempt_install in range(number_of_attempts_install):
                    code_output, error = execute_code(
                        filepath, self.config["path_debug_log"], self.agent_name, timeout=timeout
                    )
                    if error is None:
                        # No errors occurred, we are finishing debugging the code
                        any_mistakes_in_code = False
                        break

                    error_title = extract_error_title(error)
                    self.sub_logger.error(f'\t\t{error_title}')

                    if "ModuleNotFoundError" in error:
                        if install_error_story.count(error_title) > self.config['max_count_install_error']:
                            break
                        self.sub_logger.info("install module...")
                        self.install_packages(error)
                    else:
                        # If the error is of a different nature, we proceed to debugging
                        break

                if error is None:
                    # No errors occurred, we are finishing debugging the code
                    break

                if "ModuleNotFoundError" in error:
                    # We won't be able to debug this error
                    final_error_message = 'Failed to install the module'
                    self.sub_logger.info('❌\tFailed to install the module')
                    self._log_debug_metrics(
                        debug_log_file_path, is_successful=False, final_score=None,
                        total_submit_attempts=submit_debug_attempt, debug_mode=self.config["debug_mode"],
                        total_code_debug_attempts=total_code_debug_attempts_made,
                        max_repeated_errors=max_identical_error_count,
                        final_error_msg=final_error_message
                    )
                    self.write_running_time(start_work)
                    return None, final_error_message, None

                elif f"Runtime Error: exceeded {self.config['runtime_error_time']} minutes timeout" in error:
                    final_error_message = 'Runtime Error'
                    self._log_debug_metrics(
                        debug_log_file_path, is_successful=False, final_score=None,
                        total_submit_attempts=submit_debug_attempt, debug_mode=self.config["debug_mode"],
                        total_code_debug_attempts=total_code_debug_attempts_made,
                        max_repeated_errors=max_identical_error_count,
                        final_error_msg=final_error_message
                    )
                    return None, final_error_message, None

                error_title = extract_error_title(error)
                error_story.append(error)
                error_title_story.append(error_title)
                max_identical_error_count = max(max_identical_error_count, error_title_story.count(error_title))

                if error_title_story.count(error_title) >= self.config['max_count_of_identical_errors']:
                    final_error_message = 'Same type of errors, try regenerating the code'
                    self._log_debug_metrics(
                        debug_log_file_path, is_successful=False, final_score=None,
                        total_submit_attempts=submit_debug_attempt, debug_mode=self.config["debug_mode"],
                        total_code_debug_attempts=total_code_debug_attempts_made,
                        max_repeated_errors=max_identical_error_count,
                        final_error_msg=final_error_message
                    )
                    self.write_running_time(start_work)
                    return None, final_error_message, None

                match self.config["debug_mode"]:
                    case "three-stage phase":
                        code = self.run_debug_process_three_stage_phase(code, code_output, error_story, code_debug_attempt)
                    case "holistic":
                        code = self.run_debug_process_holistic(code, code_output, error_story, code_debug_attempt)
                    case _:
                        raise KeyError(f"Debug mode {self.config['debug_mode']} is not find")
                total_code_debug_attempts_made += 1

                with open(filepath, 'w') as file:
                    if not isinstance(code, str):
                        file.write('')
                    else:
                        file.write(code)

                if code_debug_attempt == number_of_attempts_debug - 1:
                    code_output, error = execute_code(
                        filepath, self.config["path_debug_log"], self.agent_name, timeout=timeout
                    )
                    if error is None:
                        any_mistakes_in_code = False
                        break

            if any_mistakes_in_code:
                # Attempts to debug the code failed
                self.sub_logger.info('❌\tCannot fix errors in code')
                final_error_message = 'Cannot fix errors in code'
                # Log data before returning
                self._log_debug_metrics(debug_log_file_path, debug_successful, final_score,
                                        submit_debug_attempt, total_code_debug_attempts_made,
                                        max_identical_error_count, final_error_message, self.config["debug_mode"])
                self.write_running_time(start_work)
                return None, final_error_message, None

            if needs_invalid:
                self.sub_logger.info("\t🛂\tChecking to see if the submit is correct.")
                result, reason, score = self.scorer.get_score(test_submit_file_path, submission_path)
                if result:
                    score_float = float(score.strip().split('\n')[-1])
                    if speed_mode == 'fast' and debug_normal_speed:
                        self.sub_logger.info(f"\t✅\tSubmit from speed code is correct! Score: {score_float}")
                        self.sub_logger.info("\t⏬\tReturn speed and start run original code")
                        code = extract_python_code(
                            self.coder.return_code_speed(
                                code=code, is_model_training_stage=needs_invalid,
                                original_code_description=original_code_description, changes_made=changes_made
                            )
                        )

                        with open(original_filepath, 'w') as file:
                            if not isinstance(code, str):
                                file.write('')
                            else:
                                file.write(code)

                        code_output, error = execute_code(
                            original_filepath, self.config["path_debug_log"], self.agent_name, timeout=timeout
                        )
                        if error is not None:
                            error_title = extract_error_title(error)
                            self.sub_logger.info(f"\t❌\tCannot fix errors in original code")
                            self.sub_logger.info(f"\t↪\tReason:\n{error_title}")
                            return None, error_title, None
                        else:
                            result, reason, score = self.scorer.get_score(test_submit_file_path, submission_path)
                            if not result:
                                self.sub_logger.error("\t❌\tSubmit from original code is incorrect!")
                                self.sub_logger.info(f"\t↪\tReason:\n{reason}")
                                return None, reason, None
                            else:
                                score_float = float(score.strip().split('\n')[-1])
                                self.sub_logger.info(f"\t✅\tSubmit is correct! Score: {score_float}")
                    else:
                        self.sub_logger.info(f"\t✅\tSubmit is correct! Score: {score_float}")
                    debug_successful = True
                    final_score = score_float
                    final_error_message = None
                    # Log data before returning
                    self._log_debug_metrics(debug_log_file_path, debug_successful, final_score,
                                            submit_debug_attempt, total_code_debug_attempts_made,
                                            max_identical_error_count, final_error_message, self.config["debug_mode"])
                    self.write_running_time(start_work)
                    return code, None, score_float

                error_title = extract_error_title(reason)
                self.sub_logger.error("\t❌\tSubmit is incorrect!")
                self.sub_logger.info(f'\t↪\tReason:\n{error_title}')
                self.sub_logger.error("\t\tDebug submit")

                match self.config["debug_mode"]:
                    case "three-stage phase":
                        code = self.run_debug_process_for_invalid(
                            code, code_output, reason,
                            submit_debug_attempt, submission_path
                        )
                    case "holistic":
                        code = self.run_holistic_debug_process(
                            code, code_output, reason,
                            submit_debug_attempt, submission_path
                        )
                    case _:
                        raise KeyError(f"Debug mode {self.config['debug_mode']} is not find")

                # We are trying to re-obtain the score after debugging the code
                # for correct submission generation
                with open(filepath, 'w') as file:
                    if not isinstance(code, str):
                        file.write('')
                    else:
                        file.write(code)
                code_output, error = execute_code(
                    filepath, self.config["path_debug_log"], self.agent_name, timeout=timeout
                )
                if error is None:
                    result, reason, score = self.scorer.get_score(test_submit_file_path, submission_path)
                    if result:
                        if speed_mode == 'fast' and debug_normal_speed:
                            self.sub_logger.info(f"\t✅\tSubmit from speed code is correct! Score: {score}")
                            self.sub_logger.info("\t⏬\tReturn speed and start run original code")
                            code = extract_python_code(self.coder.return_code_speed(
                                code=code, is_model_training_stage=needs_invalid,
                                original_code_description=original_code_description, changes_made=changes_made
                                )
                            )

                            with open(original_filepath, 'w') as file:
                                if not isinstance(code, str):
                                    file.write('')
                                else:
                                    file.write(code)
                            code_output, error = execute_code(
                                original_filepath, self.config["path_debug_log"], self.agent_name, timeout=timeout
                            )

                            if error is None:
                                result, reason, score = self.scorer.get_score(test_submit_file_path, submission_path)
                            else:
                                return None, error, None
                            if not result:
                                self.sub_logger.error("\t❌\tSubmit from original code is incorrect!")
                                self.sub_logger.info(f'\t↪\tReason:\n{reason}')
                                return None, reason, None
                        final_score = float(score.strip().split('\n')[-1])
                        self.sub_logger.info(f"\t✅\tSubmit is correct! Score: {final_score}")
                        debug_successful = True
                        final_error_message = None
                        # Log data before returning
                        self._log_debug_metrics(debug_log_file_path, debug_successful, final_score,
                                                submit_debug_attempt, total_code_debug_attempts_made,
                                                max_identical_error_count, final_error_message,
                                                self.config["debug_mode"])
                        self.write_running_time(start_work)
                        return code, None, final_score
                    else:
                        error_story.append(error)
                else:
                    error_story.append(error)

            else:
                # If there is no need to generate a submission,
                # we consider the debugging to be successfully completed
                if speed_mode == 'fast':
                    self.sub_logger.info("\t⏬\tReturn speed and start run original code")
                    code = extract_python_code(
                        self.coder.return_code_speed(
                                code=code, is_model_training_stage=needs_invalid,
                                original_code_description=original_code_description, changes_made=changes_made
                        )
                    )
                    with open(original_filepath, 'w') as file:
                        if not isinstance(code, str):
                            file.write('')
                        else:
                            file.write(code)

                    code_output, error = execute_code(original_filepath, self.config["path_debug_log"], self.agent_name,
                                                      timeout=timeout)
                    if error is not None:
                        error_title = extract_error_title(error)
                        self.sub_logger.info(f"\t❌\tDebug failed after speeding up")
                        self.sub_logger.info(f"\t↪\tReason:\n{error_title}")

                        self._log_debug_metrics(debug_log_file_path, False, None,
                                                submit_debug_attempt, total_code_debug_attempts_made,
                                                max_identical_error_count, error_title, self.config["debug_mode"])
                        self.write_running_time(start_work)
                        return None, error_title, None

                self.sub_logger.info("\t✅\tSuccess debug")
                debug_successful = True
                final_error_message = None
                # Log data before returning
                self._log_debug_metrics(debug_log_file_path, debug_successful, final_score,
                                        submit_debug_attempt, total_code_debug_attempts_made,
                                        max_identical_error_count, final_error_message, self.config["debug_mode"])
                self.write_running_time(start_work)
                return code, code_output, None

        # Attempts to debug the code failed.
        self.sub_logger.info('❌\tSubmission file is invalid, cannot fix')
        final_error_message = 'Submission file is invalid, cannot fix'
        # Log data before returning
        self._log_debug_metrics(debug_log_file_path, debug_successful, final_score,
                                submit_debug_attempt, total_code_debug_attempts_made,
                                max_identical_error_count, final_error_message, self.config["debug_mode"])
        self.write_running_time(start_work)
        return None, final_error_message, None
