from typing import Optional

from logger import logger

from .episode_types import HonestResult
from .honest_models import HonestModel
from .paraphrasers import Paraphraser
from .tasks import TaskInstance


class HonestEpisode:
    """
    A single episode driven by an honest model.

    The honest model produces a message for a task; the message is optionally
    passed through a paraphraser, and the task is used to evaluate the
    resulting message(s).
    """

    def __init__(
        self,
        honest_model: HonestModel,
        paraphraser: Optional[Paraphraser] = None,
    ):
        """
        Store the collaborators used by `run`.

        Args:
            honest_model: The model asked to produce a message for each task.
            paraphraser: Optional paraphraser applied to the model's message.
                When None, no paraphrasing step is performed.
        """
        self.honest_model = honest_model
        self.paraphraser = paraphraser

    def run(self, task: TaskInstance) -> HonestResult:
        """
        Run the episode on the given task.

        Args:
            task: The task handed to the honest model, and used to evaluate
                its output.
                Example: QATaskSimple("What colour is Santa Claus?", "red")

        Returns:
            A HonestResult built positionally from, in order: the evaluation
            of the raw message, the evaluation of the paraphrased message
            (or the raw message's evaluation again when there is no
            paraphrased message), the raw message, and the paraphrased
            message (None when there is none).
        """
        logger.debug(f"Task was: {task}")
        message = self.honest_model.get_message(task)

        # Paraphrasing is optional; None marks "no paraphrased message".
        paraphrased = None
        if self.paraphraser is not None:
            paraphrased = self.paraphraser.paraphrase(message, task)

        # Evaluate the raw message first, then the paraphrased one, so the
        # calls into `task.evaluate` happen in a fixed order.
        evaluation = task.evaluate(message)
        if paraphrased is None:
            # Nothing to evaluate separately: reuse the raw evaluation.
            paraphrased_evaluation = evaluation
        else:
            paraphrased_evaluation = task.evaluate(paraphrased)

        return HonestResult(
            evaluation,
            paraphrased_evaluation,
            message,
            paraphrased,
        )
