"""
PSAO Prompt Optimiser Implementation

This module provides an implementation of the prompt optimiser interface that
optimises prompts by adjusting the PSAO variables of different sentences.
"""

import copy
from typing import Any, Dict, List

import numpy as np
import optuna

from src.core.registry import prompt_optimiser_registry
from src.llm.base import LLMInterface
from src.prompt_optimisation.base import PromptOptimiserInterface
from src.tasks.base import TaskInterface
from src.utils import prompt_utils as Util
from src.utils.decorator_utils import with_logger


@prompt_optimiser_registry.register("psao")
class PSAOOptunaPromptOptimiser(PromptOptimiserInterface):
    """
    PSAO - Optuna prompt optimiser implementation.

    The system prompt of a message template is split into segments on ".".
    Every segment is followed by a copy of the structured-annotation template
    (``psao_struct_ann``) in which the placeholder ``ann_var`` is replaced by
    an integer annotation variable. Optuna searches over those integers
    (one parameter ``ann_var_<i>`` per segment) to maximise the task score.

    NOTE(review): ``logger`` is not defined in the visible part of this
    module; presumably the ``with_logger`` decorator makes it available to
    the decorated callables - confirm against ``decorator_utils``.
    """

    # Token inside ``psao_struct_ann`` that is replaced by the annotation value.
    _ANN_VAR_PLACEHOLDER = "ann_var"
    # Inclusive bounds shared by the random initialisation and the Optuna
    # search space, so the two can never drift apart.
    _ANN_VAR_MIN = 1
    _ANN_VAR_MAX = 10

    @with_logger
    def __init__(
        self,
        psao_intro_prompt: str,
        psao_struct_ann: str,
        r_seed: int = 42,
        optimise_user_prompt_flag: bool = False,
        optuna_db_name: str = "prompt_opt_psao_db.db",
        optuna_study_name: str = "prompt_opt_psao_study",
        optuna_n_trials: int = 5,
        **kwargs: Any,
    ):
        """
        Initialise the PSAO prompt optimiser.

        Args:
            psao_intro_prompt: Text prepended to the annotated system prompt
            psao_struct_ann: Annotation template appended to every segment;
                each occurrence of ``ann_var`` in it is replaced by the
                segment's annotation value
            r_seed: The random seed used to initialise the annotation values
            optimise_user_prompt_flag: Whether to optimise user prompts as
                well (stored only; user prompts are not optimised yet)
            optuna_db_name: The name of the Optuna SQLite database file
            optuna_study_name: The name of the Optuna study
            optuna_n_trials: The number of trials to run for optimisation
            **kwargs: Additional keyword arguments (currently ignored)
        """
        # Logging initialisation parameters
        logger.info("initialising PSAOOptunaPromptOptimiser")
        logger.info(f"Random seed: {r_seed}")
        logger.info(f"Optimise user prompt flag: {optimise_user_prompt_flag}")
        logger.info(f"Optuna DB name: {optuna_db_name}")
        logger.info(f"Optuna study name: {optuna_study_name}")
        logger.info(f"Optuna n_trials: {optuna_n_trials}")

        # Setup attributes
        self.r_seed = r_seed
        self.psao_intro_prompt = psao_intro_prompt
        self.psao_struct_ann = psao_struct_ann
        self.optimise_user_prompt_flag = optimise_user_prompt_flag
        self.optuna_db_name = optuna_db_name
        self.optuna_study_name = optuna_study_name
        self.optuna_n_trials = optuna_n_trials

        # Prompt segments and their annotation values; populated by
        # set_message_template() and optimise().
        self.msg_template_lst = []
        self.sys_prompt_segment_lst = []
        self.sys_prompt_struct_ann_var_lst = []
        self.sys_prompt_struct_ann_var_lst_best = []

        logger.info("PSAO - optuna Promptoptimiser initialised successfully")

    @with_logger
    def _join_segment_and_annotation(
        self,
        seg_lst: List[str],
        ann_var_lst: List[int],
    ) -> str:
        """
        Join prompt segments with their filled-in annotations.

        Args:
            seg_lst: The prompt segments
            ann_var_lst: One annotation value per segment; pairs are formed
                with ``zip``, so surplus items on either side are dropped

        Returns:
            The segments, each followed by its annotation, joined by ". ".
        """
        # Substitute the placeholder in the annotation template only, so that
        # a segment which happens to contain the placeholder text is left
        # untouched.
        annotated = [
            f"{segment} "
            f"{self.psao_struct_ann.replace(self._ANN_VAR_PLACEHOLDER, str(value))}"
            for segment, value in zip(seg_lst, ann_var_lst)
        ]
        return ". ".join(annotated)

    @with_logger
    def _create_msg_from_prompt_segment_and_annotation(
        self,
        best_flag: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Build a message list whose system prompt carries the annotations.

        Args:
            best_flag: Use the best annotation values found by optimise()
                instead of the current working values

        Returns:
            A new message list: the annotated system message followed by a
            copy of every template message after the first one.
        """
        ann_var_lst = (
            self.sys_prompt_struct_ann_var_lst_best
            if best_flag
            else self.sys_prompt_struct_ann_var_lst
        )
        sys_prompt = self._join_segment_and_annotation(
            self.sys_prompt_segment_lst,
            ann_var_lst,
        )

        msg_lst = [{"role": "system", "content": self.psao_intro_prompt + sys_prompt}]

        # The first template entry is replaced by the annotated system message
        # (assumes the template starts with the system message). The rest is
        # copied so consumers cannot mutate the stored template.
        msg_lst.extend(copy.deepcopy(self.msg_template_lst[1:]))

        # TODO: annotate user prompts as well

        return msg_lst

    @with_logger
    def set_message_template(
        self,
        message_template_lst: List[Dict[str, Any]],
    ) -> None:
        """
        Initialise the optimiser from a message template.

        Stores a deep copy of the template, splits its system prompt into
        segments on "." and draws one random initial annotation value per
        segment, seeded with ``r_seed``.

        Args:
            message_template_lst: The message template to optimise
        """
        self.msg_template_lst = copy.deepcopy(message_template_lst)

        # TODO: only segment the system prompt for now
        # extend this into multi turns in the future
        sys_prompt = Util.extract_system_prompt(self.msg_template_lst)
        # NOTE(review): a prompt ending in "." produces a trailing empty
        # segment, which still receives an annotation - confirm whether that
        # is intended.
        self.sys_prompt_segment_lst = sys_prompt.split(".")

        # A private RandomState yields the same values as seeding the global
        # NumPy generator, without reseeding it for the rest of the process.
        rng = np.random.RandomState(self.r_seed)
        self.sys_prompt_struct_ann_var_lst = rng.randint(
            self._ANN_VAR_MIN,
            self._ANN_VAR_MAX + 1,  # randint's upper bound is exclusive
            size=len(self.sys_prompt_segment_lst),
        ).tolist()

    @with_logger
    def optimise(
        self,
        task: TaskInterface,
        llm: LLMInterface,
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """
        Optimise the annotation values with Optuna.

        Each trial suggests one integer per segment, builds the annotated
        message template, installs it on ``task`` and uses the score returned
        by ``task.run(llm)`` as the objective to maximise. The study is stored
        in a SQLite database and reused if it already exists.

        Args:
            task: The task used to evaluate a candidate message template
            llm: The LLM the task is run with
            **kwargs: Additional keyword arguments (currently ignored)

        Returns:
            The message list built from the best annotation values
        """
        n_vars = len(self.sys_prompt_struct_ann_var_lst)

        # Define the objective function for Optuna
        @with_logger
        def objective_func(trial):
            logger.info(f"Running trial {trial.number}")

            # Define the integer parameters for annotation variables
            params = {}
            for i in range(n_vars):
                param_name = f"ann_var_{i}"
                params[param_name] = trial.suggest_int(
                    param_name, self._ANN_VAR_MIN, self._ANN_VAR_MAX
                )
                self.sys_prompt_struct_ann_var_lst[i] = params[param_name]

            logger.info(f"Trial parameters: {params}")

            # Create the prompt with the current annotation variables
            new_msg_template = self._create_msg_from_prompt_segment_and_annotation()

            # Evaluate the prompt by running the task
            task.update_prompt_msg_template(new_msg_template)
            _, score, error_tracker = task.run(llm)
            logger.info(f"Evaluation score: {score}")

            # Log any errors that occurred during this trial
            if error_tracker and error_tracker.get_error_count() > 0:
                logger.warning(
                    f"Trial {trial.number} had {error_tracker.get_error_count()} LLM errors"
                )

            return score

        # Create or reuse study object for knowledge transfer
        study = optuna.create_study(
            study_name=self.optuna_study_name,
            direction="maximize",
            storage=f"sqlite:///{self.optuna_db_name}",
            load_if_exists=True,
        )

        # Log current study state. This is best-effort diagnostics only, so
        # any failure is reported at debug level and never aborts the run.
        try:
            existing_trials = getattr(study, "trials", None)
            if existing_trials:
                logger.debug(
                    f"Study has {len(existing_trials)} existing trials from previous runs"
                )
                if study.best_trial:
                    try:
                        logger.debug(
                            f"Current best trial value: {study.best_value:.4f}"
                        )
                    except (TypeError, AttributeError):
                        # Handle case where study.best_value is a mock or not a number
                        logger.debug(f"Current best trial value: {study.best_value}")
            else:
                logger.debug("No existing trials found - starting fresh study")
        except Exception as e:
            logger.debug(
                f"Could not access study state (this is normal for new studies): {str(e)}"
            )

        logger.info(f"Starting optimisation with {self.optuna_n_trials} trials")
        study.optimize(objective_func, n_trials=self.optuna_n_trials)
        logger.info("Optimisation complete")
        logger.info(f"Best trial: {study.best_trial.number}")
        logger.info(f"Best value: {study.best_value}")

        # Read the best parameters once instead of once per use.
        best_params = study.best_params
        logger.info(f"Best params: {best_params}")

        # Store the best parameters
        logger.info("Storing best parameters")
        best_ann_var_lst = []
        for i, current_value in enumerate(self.sys_prompt_struct_ann_var_lst):
            param_name = f"ann_var_{i}"
            if param_name in best_params:
                best_ann_var_lst.append(best_params[param_name])
                logger.info(
                    f"Using best param for {param_name}: {best_params[param_name]}"
                )
            else:
                # The best trial may come from an earlier run of a reused study
                # with fewer segments. Keep the current value so the list stays
                # aligned with the segments and no segment is dropped.
                logger.warning(
                    f"Best trial has no value for {param_name}; "
                    f"keeping current value {current_value}"
                )
                best_ann_var_lst.append(int(current_value))
        self.sys_prompt_struct_ann_var_lst_best = best_ann_var_lst

        # Log optimisation statistics
        self._log_optimisation_stats(study)

        return self._create_msg_from_prompt_segment_and_annotation(best_flag=True)

    @with_logger
    def apply(
        self,
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """
        Build the message list that uses the best annotation values.

        Args:
            **kwargs: Additional keyword arguments (currently ignored)

        Returns:
            The transformed prompt message list
        """
        if self.sys_prompt_segment_lst and not self.sys_prompt_struct_ann_var_lst_best:
            # Without best values every segment is dropped from the system
            # prompt; make that visible instead of failing silently.
            logger.warning(
                "apply() called without stored best annotation values; "
                "the system prompt will contain the intro prompt only"
            )

        # Create the prompt with the best annotation variables
        return self._create_msg_from_prompt_segment_and_annotation(best_flag=True)

    @with_logger
    def _log_optimisation_stats(self, study) -> None:
        """
        Log summary statistics of a finished study.

        Args:
            study: The Optuna study (or a test double exposing ``trials``,
                ``best_trial``, ``best_value`` and ``best_params``)
        """
        logger.info("=== optimisation Statistics ===")
        logger.info(f"Total trials completed: {len(study.trials)}")
        logger.info(f"Best trial number: {study.best_trial.number}")

        # Handle potential mock objects in tests
        try:
            logger.info(f"Best value: {study.best_value:.4f}")
        except (TypeError, AttributeError):
            logger.info(f"Best value: {study.best_value}")

        # Log the best parameter values
        best_params = study.best_params
        if best_params:
            logger.info("Best parameters:")
            for param_name, param_value in best_params.items():
                logger.info(f"  {param_name}: {param_value}")

        logger.info("=== End Statistics ===")
